Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
R
re6stnet
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
2
Issues
2
List
Boards
Labels
Milestones
Merge Requests
4
Merge Requests
4
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
nexedi
re6stnet
Commits
e80ca878
Commit
e80ca878
authored
Oct 07, 2024
by
Thomas Gambier
🚴🏼
Browse files
Options
Browse Files
Download
Plain Diff
Port re6stnet to python3
See merge request
!46
parents
f2fd7247
227d63d4
Changes
36
Hide whitespace changes
Inline
Side-by-side
Showing
36 changed files
with
893 additions
and
763 deletions
+893
-763
.gitignore
.gitignore
+8
-0
README.rst
README.rst
+1
-1
demo/README.rst
demo/README.rst
+25
-0
demo/demo
demo/demo
+171
-124
demo/fixnemu.py
demo/fixnemu.py
+0
-89
demo/py
demo/py
+3
-2
demo/test_hmac.py
demo/test_hmac.py
+1
-1
draft/re6st-cn
draft/re6st-cn
+1
-1
re6st/cache.py
re6st/cache.py
+36
-40
re6st/cli/conf.py
re6st/cli/conf.py
+29
-27
re6st/cli/node.py
re6st/cli/node.py
+9
-15
re6st/cli/registry.py
re6st/cli/registry.py
+9
-9
re6st/ctl.py
re6st/ctl.py
+28
-38
re6st/debug.py
re6st/debug.py
+11
-11
re6st/multicast.py
re6st/multicast.py
+2
-2
re6st/ovpn-client
re6st/ovpn-client
+3
-2
re6st/ovpn-server
re6st/ovpn-server
+4
-3
re6st/plib.py
re6st/plib.py
+18
-10
re6st/registry.py
re6st/registry.py
+156
-111
re6st/tests/__init__.py
re6st/tests/__init__.py
+1
-1
re6st/tests/test_end2end/test_registry_client.py
re6st/tests/test_end2end/test_registry_client.py
+10
-9
re6st/tests/test_network/network_build.py
re6st/tests/test_network/network_build.py
+7
-11
re6st/tests/test_network/re6st_wrap.py
re6st/tests/test_network/re6st_wrap.py
+26
-16
re6st/tests/test_network/test_net.py
re6st/tests/test_network/test_net.py
+43
-27
re6st/tests/test_unit/test_conf.py
re6st/tests/test_unit/test_conf.py
+4
-4
re6st/tests/test_unit/test_registry.py
re6st/tests/test_unit/test_registry.py
+63
-30
re6st/tests/test_unit/test_registry_client.py
re6st/tests/test_unit/test_registry_client.py
+10
-9
re6st/tests/test_unit/test_tunnel/test_base_tunnel_manager.py
...t/tests/test_unit/test_tunnel/test_base_tunnel_manager.py
+8
-8
re6st/tests/test_unit/test_tunnel/test_multi_gateway_manager.py
...tests/test_unit/test_tunnel/test_multi_gateway_manager.py
+1
-1
re6st/tests/tools.py
re6st/tests/tools.py
+12
-15
re6st/tunnel.py
re6st/tunnel.py
+64
-49
re6st/upnpigd.py
re6st/upnpigd.py
+8
-8
re6st/utils.py
re6st/utils.py
+29
-32
re6st/version.py
re6st/version.py
+1
-1
re6st/x509.py
re6st/x509.py
+80
-49
setup.py
setup.py
+11
-7
No files found.
.gitignore
View file @
e80ca878
...
@@ -5,3 +5,11 @@
...
@@ -5,3 +5,11 @@
/build/
/build/
/dist/
/dist/
/re6stnet.egg-info/
/re6stnet.egg-info/
*.log
*.pid
*.db
*.state
*.crt
*.pem
demo/mbox
*.sock
README.rst
View file @
e80ca878
...
@@ -52,7 +52,7 @@ easily scalable to tens of thousand of nodes.
...
@@ -52,7 +52,7 @@ easily scalable to tens of thousand of nodes.
Requirements
Requirements
============
============
- Python
2.7
- Python
3.11
- OpenSSL binary and development libraries
- OpenSSL binary and development libraries
- OpenVPN 2.4.*
- OpenVPN 2.4.*
- Babel_ (with Nexedi patches)
- Babel_ (with Nexedi patches)
...
...
demo/README.rst
0 → 100644
View file @
e80ca878
Demo
====
Usage
-----
To run the demo, make sure all the dependencies are installed
and run ``./demo 8000`` (or any port).
Troubleshooting
---------------
If the demo crashes and fails to clean up its resources properly,
run the following commands::
for b in $(sudo ip l | grep -Po 'NETNS\w\w[\d\-a-f]+'); do sudo ip l del $b; done
pkill screen
killall python
killall python3
find . -name '*.crt' -delete; find . -name '*.db' -delete; find . -name '*.log' -delete
.. warning::
This will kill all Python processes. These commands assume you're running
the demo on a dedicated machine with nothing else on it.
demo/demo
View file @
e80ca878
#!/usr/bin/
python2
#!/usr/bin/
env python3
import argparse, math, nemu, os, re, signal
import argparse, math, nemu, os, re, s
hlex, s
ignal
import socket, sqlite3, subprocess, sys, time, weakref
import socket, sqlite3, subprocess, sys, time, weakref
from collections import defaultdict
from collections import defaultdict
from contextlib import contextmanager
from contextlib import contextmanager
from threading import Thread
from threading import Thread
from typing import Optional
IPTABLES = 'iptables'
IPTABLES = 'iptables'
SCREEN = 'screen'
SCREEN = 'screen'
VERBOSE = 4
VERBOSE = 4
...
@@ -14,9 +16,9 @@ REGISTRY2_SERIAL = '0x120010db80043'
...
@@ -14,9 +16,9 @@ REGISTRY2_SERIAL = '0x120010db80043'
CA_DAYS = 1000
CA_DAYS = 1000
# Quick check to avoid wasting time if there is an error.
# Quick check to avoid wasting time if there is an error.
with open(os.devnull, "wb") as f
:
for x in 're6stnet', 're6st-conf', 're6st-registry'
:
for x in 're6stnet', 're6st-conf', 're6st-registry':
subprocess.check_call(('./py', x, '--help'), stdout=subprocess.DEVNULL)
subprocess.check_call(('./py', x, '--help'), stdout=f)
#
#
# Underlying network:
# Underlying network:
#
#
...
@@ -55,14 +57,14 @@ def _add_interface(node, iface):
...
@@ -55,14 +57,14 @@ def _add_interface(node, iface):
nemu.Node._add_interface = _add_interface
nemu.Node._add_interface = _add_interface
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument('port', type
=
int,
parser.add_argument('port', type
=
int,
help
=
'port used to display tunnels')
help
=
'port used to display tunnels')
parser.add_argument('-d', '--duration', type
=
int,
parser.add_argument('-d', '--duration', type
=
int,
help
=
'time of the demo execution in seconds')
help
=
'time of the demo execution in seconds')
parser.add_argument('-p', '--ping', action
=
'store_true',
parser.add_argument('-p', '--ping', action
=
'store_true',
help
=
'execute ping utility')
help
=
'execute ping utility')
parser.add_argument('-m', '--hmac', action
=
'store_true',
parser.add_argument('-m', '--hmac', action
=
'store_true',
help
=
'execute HMAC test')
help
=
'execute HMAC test')
args = parser.parse_args()
args = parser.parse_args()
def handler(signum, frame):
def handler(signum, frame):
...
@@ -72,33 +74,55 @@ if args.duration:
...
@@ -72,33 +74,55 @@ if args.duration:
signal.signal(signal.SIGALRM, handler)
signal.signal(signal.SIGALRM, handler)
signal.alarm(args.duration)
signal.alarm(args.duration)
execfile("fixnemu.py")
class Re6stNode(nemu.Node):
name: str
short: str
re6st_cmdline: Optional[list[str]]
def __init__(self, name, short):
super().__init__()
self.name = name
self.short = short
self.Popen(('sysctl', '-q',
'net.ipv4.icmp_echo_ignore_broadcasts=0')).wait()
self._screen = self.Popen((SCREEN, '-DmS', name))
self.re6st_cmdline = None
def screen(self, command: list[str]):
runner_cmd = ('set -- %s; "\\$@"; echo "\\$@"; exec $SHELL' %
' '.join(map(shlex.quote, command)))
inner_cmd = [
'screen', 'sh', '-c', runner_cmd
]
cmd = [
SCREEN, '-r', self.name, '-X', 'eval', shlex.join(inner_cmd)
]
return subprocess.call(cmd)
# create nodes
# create nodes
for name in """internet=I registry=R
internet = Re6stNode('internet', 'I')
gateway1=g1 machine1=1 machine2=2
registry = Re6stNode('registry', 'R')
gateway2=g2 machine3=3 machine4=4 machine5=5
gateway1 = Re6stNode('gateway1', 'g1')
machine6=6 machine7=7 machine8=8 machine9=9
machine1 = Re6stNode('machine1', '1')
registry2=R2 machine10=10
machine2 = Re6stNode('machine2', '2')
""".split():
gateway2 = Re6stNode('gateway2', 'g2')
name, short = name.split('=')
machine3 = Re6stNode('machine3', '3')
globals()[name] = node = nemu.Node()
machine4 = Re6stNode('machine4', '4')
node.name = name
machine5 = Re6stNode('machine5', '5')
node.short = short
machine6 = Re6stNode('machine6', '6')
node.Popen(('sysctl', '-q',
machine7 = Re6stNode('machine7', '7')
'net.ipv4.icmp_echo_ignore_broadcasts=0')).wait()
machine8 = Re6stNode('machine8', '8')
node._screen = node.Popen((SCREEN, '-DmS', name))
machine9 = Re6stNode('machine9', '9')
node.screen = (lambda name: lambda *cmd:
registry2 = Re6stNode('registry2', 'R2')
subprocess.call([SCREEN, '-r', name, '-X', 'eval'] + map(
machine10 = Re6stNode('machine10', '10')
"""screen sh -c 'set %s; "\$@"; echo "\$@"; exec $SHELL'"""
.__mod__, cmd)))(name)
# create switch
# create switch
switch1 = nemu.Switch()
switch1 = nemu.Switch()
switch2 = nemu.Switch()
switch2 = nemu.Switch()
switch3 = nemu.Switch()
switch3 = nemu.Switch()
#create interfaces
#
create interfaces
re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet)
re_if_0, in_if_0 = nemu.P2PInterface.create_pair(registry, internet)
in_if_1, g1_if_0 = nemu.P2PInterface.create_pair(internet, gateway1)
in_if_1, g1_if_0 = nemu.P2PInterface.create_pair(internet, gateway1)
in_if_2, g2_if_0 = nemu.P2PInterface.create_pair(internet, gateway2)
in_if_2, g2_if_0 = nemu.P2PInterface.create_pair(internet, gateway2)
...
@@ -205,19 +229,19 @@ for m in machine6, machine7, machine8:
...
@@ -205,19 +229,19 @@ for m in machine6, machine7, machine8:
# Test connectivity first. Run process, hide output and check
# Test connectivity first. Run process, hide output and check
# return code
# return code
null = file(os.devnull, "r+")
for ip in '10.1.1.2', '10.1.1.3', '10.2.1.2', '10.2.1.3':
for ip in '10.1.1.2', '10.1.1.3', '10.2.1.2', '10.2.1.3':
if machine1.Popen(('ping', '-c1', ip), stdout=
null
).wait():
if machine1.Popen(('ping', '-c1', ip), stdout=
subprocess.DEVNULL
).wait():
print
'Failed to ping %s' % ip
print
('Failed to ping', ip)
break
break
else:
else:
print "Connectivity IPv4 OK!"
print("Connectivity IPv4 OK!")
nodes: list[Re6stNode] = []
gateway1.screen(['miniupnpd', '-d', '-f', 'miniupnpd.conf', '-P',
'miniupnpd.pid', '-a', g1_if_1.name, '-i', g1_if_0_name])
nodes = []
gateway1.screen('miniupnpd -d -f miniupnpd.conf -P miniupnpd.pid'
' -a %s -i %s' % (g1_if_1.name, g1_if_0_name))
@contextmanager
@contextmanager
def new_network(registry
, reg_addr, serial, ca
):
def new_network(registry
: Re6stNode, reg_addr: str, serial: str, ca: str
):
from OpenSSL import crypto
from OpenSSL import crypto
import hashlib, sqlite3
import hashlib, sqlite3
os.path.exists(ca) or subprocess.check_call(
os.path.exists(ca) or subprocess.check_call(
...
@@ -225,16 +249,18 @@ def new_network(registry, reg_addr, serial, ca):
...
@@ -225,16 +249,18 @@ def new_network(registry, reg_addr, serial, ca):
" -subj /CN=re6st.example.com/emailAddress=re6st@example.com"
" -subj /CN=re6st.example.com/emailAddress=re6st@example.com"
" -set_serial %s -days %u"
" -set_serial %s -days %u"
% (registry.name, ca, serial, CA_DAYS), shell=True)
% (registry.name, ca, serial, CA_DAYS), shell=True)
with open(ca) as f:
with open(ca
, "rb"
) as f:
cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
cert = crypto.load_certificate(crypto.FILETYPE_PEM, f.read())
fingerprint = "sha256:" + hashlib.sha256(
fingerprint = "sha256:" + hashlib.sha256(
crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)).hexdigest()
crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)).hexdigest()
db_path = "%s/registry.db" % registry.name
db_path = "%s/registry.db" % registry.name
registry.screen("./py re6st-registry @%s/re6st-registry.conf"
registry.screen([
" --db %s --mailhost %s -v%u"
sys.executable, './py', 're6st-registry',
% (registry.name, db_path, os.path.abspath('mbox'), VERBOSE))
'@%s/re6st-registry.conf' % registry.name, '--db', db_path,
'--mailhost', os.path.abspath('mbox'), '-v%u' % VERBOSE,
])
registry_url = 'http://%s/' % reg_addr
registry_url = 'http://%s/' % reg_addr
registry.Popen((
'python'
, '-c', """if 1:
registry.Popen((
sys.executable
, '-c', """if 1:
import socket, time
import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True:
while True:
...
@@ -245,16 +271,21 @@ def new_network(registry, reg_addr, serial, ca):
...
@@ -245,16 +271,21 @@ def new_network(registry, reg_addr, serial, ca):
time.sleep(.1)
time.sleep(.1)
""")).wait()
""")).wait()
db = sqlite3.connect(db_path, isolation_level=None)
db = sqlite3.connect(db_path, isolation_level=None)
def new_node(node, folder, args='', prefix_len=None, registry=registry_url):
def new_node(node: Re6stNode, folder: str, args: list[str]=[],
prefix_len: Optional[int] = None, registry=registry_url):
nodes.append(node)
nodes.append(node)
if not os.path.exists(folder + '/cert.crt'):
if not os.path.exists(folder + '/cert.crt'):
dh_path = folder + '/dh2048.pem'
dh_path = folder + '/dh2048.pem'
if not os.path.exists(dh_path):
if not os.path.exists(dh_path):
os.symlink('../dh2048.pem', dh_path)
os.symlink('../dh2048.pem', dh_path)
email = node.name + '@example.com'
email = node.name + '@example.com'
p = node.Popen(('../py', 're6st-conf', '--registry', registry,
p = node.Popen((
'--email', email, '--fingerprint', fingerprint),
sys.executable, '../py', 're6st-conf',
stdin=subprocess.PIPE, cwd=folder)
'--registry', registry,
'--email', email,
'--fingerprint', fingerprint,
), stdin=subprocess.PIPE, cwd=folder)
token = None
token = None
while not token:
while not token:
time.sleep(.1)
time.sleep(.1)
...
@@ -266,27 +297,30 @@ def new_network(registry, reg_addr, serial, ca):
...
@@ -266,27 +297,30 @@ def new_network(registry, reg_addr, serial, ca):
p.communicate(str(token[0]))
p.communicate(str(token[0]))
os.remove(dh_path)
os.remove(dh_path)
os.remove(folder + '/ca.crt')
os.remove(folder + '/ca.crt')
node.re6st_cmdline = (
node.re6st_cmdline = [
'./py re6stnet @%s/re6stnet.conf -v%u --registry %s'
sys.executable, './py', 're6stnet', '@%s/re6stnet.conf' % folder,
' --console %s/run/console.sock %s'
'-v%u' % VERBOSE, '--registry', registry, '--console',
) % (folder, VERBOSE, registry, folder, args)
'%s/run/console.sock' % folder, *args,
]
node.screen(node.re6st_cmdline)
node.screen(node.re6st_cmdline)
new_node(registry, registry.name, '--ip ' + reg_addr, registry='http://localhost/')
new_node(registry, registry.name, ['--ip', reg_addr],
registry='http://localhost/')
yield new_node
yield new_node
db.close()
db.close()
with new_network(registry, REGISTRY, REGISTRY_SERIAL, 'ca.crt') as new_node:
with new_network(registry, REGISTRY, REGISTRY_SERIAL, 'ca.crt') as new_node:
new_node(machine1, 'm1',
'-I%s' % m1_if_0.name
)
new_node(machine1, 'm1',
['-I%s' % m1_if_0.name]
)
new_node(machine2, 'm2',
'--remote-gateway 10.1.1.1'
, prefix_len=77)
new_node(machine2, 'm2',
['--remote-gateway', '10.1.1.1']
, prefix_len=77)
new_node(machine3, 'm3',
'-i%s' % m3_if_0.name
)
new_node(machine3, 'm3',
['-i%s' % m3_if_0.name]
)
new_node(machine4, 'm4',
'-i%s' % m4_if_0.name
)
new_node(machine4, 'm4',
['-i%s' % m4_if_0.name]
)
new_node(machine5, 'm5',
'-i%s' % m5_if_0.name
)
new_node(machine5, 'm5',
['-i%s' % m5_if_0.name]
)
new_node(machine6, 'm6',
'-I%s' % m6_if_1.name
)
new_node(machine6, 'm6',
['-I%s' % m6_if_1.name]
)
new_node(machine7, 'm7')
new_node(machine7, 'm7')
new_node(machine8, 'm8')
new_node(machine8, 'm8')
with new_network(registry2, REGISTRY2, REGISTRY2_SERIAL, 'ca2.crt') as new_node:
with new_network(registry2, REGISTRY2, REGISTRY2_SERIAL, 'ca2.crt') as new_node:
new_node(machine10, 'm10',
'-i%s' % m10_if_0.name
)
new_node(machine10, 'm10',
['-i%s' % m10_if_0.name]
)
if args.ping:
if args.ping:
for j, machine in enumerate(nodes):
for j, machine in enumerate(nodes):
...
@@ -297,10 +331,10 @@ if args.ping:
...
@@ -297,10 +331,10 @@ if args.ping:
'2001:db8:43:1::1' if i == 10 else
'2001:db8:43:1::1' if i == 10 else
# Only 1 address for machine2 because prefix_len = 80,+48 = 128
# Only 1 address for machine2 because prefix_len = 80,+48 = 128
'2001:db8:42:%s::1' % i
'2001:db8:42:%s::1' % i
for i in
x
range(11)
for i in range(11)
if i != j]
if i != j]
name = machine.name if machine.short[0] == 'R' else 'm' + machine.short
name = machine.name if machine.short[0] == 'R' else 'm' + machine.short
machine.screen(
'python ping.py {} {}'.format(name, ' '.join(ips))
)
machine.screen(
['python', 'ping.py', name] + ips
)
class testHMAC(Thread):
class testHMAC(Thread):
...
@@ -308,36 +342,36 @@ class testHMAC(Thread):
...
@@ -308,36 +342,36 @@ class testHMAC(Thread):
updateHMAC = ('python', '-c', "import urllib, sys; sys.exit("
updateHMAC = ('python', '-c', "import urllib, sys; sys.exit("
"204 != urllib.urlopen('http://127.0.0.1/updateHMAC').code)")
"204 != urllib.urlopen('http://127.0.0.1/updateHMAC').code)")
reg1_db = sqlite3.connect('registry/registry.db', isolation_level=None,
reg1_db = sqlite3.connect('registry/registry.db', isolation_level=None,
check_same_thread=False)
check_same_thread=False)
reg2_db = sqlite3.connect('registry2/registry.db', isolation_level=None,
reg2_db = sqlite3.connect('registry2/registry.db', isolation_level=None,
check_same_thread=False)
check_same_thread=False)
reg1_db.text_factory = reg2_db.text_factory = str
reg1_db.text_factory = reg2_db.text_factory = str
m_net1 = 'registry', 'm1', 'm2', 'm3', 'm4', 'm5', 'm6', 'm7', 'm8'
m_net1 = 'registry', 'm1', 'm2', 'm3', 'm4', 'm5', 'm6', 'm7', 'm8'
m_net2 = 'registry2', 'm10'
m_net2 = 'registry2', 'm10'
print
'Testing HMAC, letting the time to machines to create tunnels...'
print
('Testing HMAC, letting the time to machines to create tunnels...')
time.sleep(45)
time.sleep(45)
print
'Check that the initial HMAC config is deployed on network 1'
print
('Check that the initial HMAC config is deployed on network 1')
test_hmac.checkHMAC(reg1_db, m_net1)
test_hmac.checkHMAC(reg1_db, m_net1)
print
'Test that a HMAC update works with nodes that are up'
print
('Test that a HMAC update works with nodes that are up')
registry.backticks_raise(updateHMAC)
registry.backticks_raise(updateHMAC)
print
'Updated HMAC (config = hmac0
&
hmac1), waiting...'
print
('Updated HMAC (config = hmac0
&
hmac1), waiting...')
time.sleep(60)
time.sleep(60)
print
'Checking HMAC on machines connected to registry 1...'
print
('Checking HMAC on machines connected to registry 1...')
test_hmac.checkHMAC(reg1_db, m_net1)
test_hmac.checkHMAC(reg1_db, m_net1)
print
('Test that machines can update upon reboot ' +
print
('Test that machines can update upon reboot '
'when they were off during a HMAC update.')
'when they were off during a HMAC update.')
test_hmac.killRe6st(machine1)
test_hmac.killRe6st(machine1)
print
'Re6st on machine 1 is stopped'
print
('Re6st on machine 1 is stopped')
time.sleep(5)
time.sleep(5)
registry.backticks_raise(updateHMAC)
registry.backticks_raise(updateHMAC)
print
'Updated HMAC on registry (config = hmac1
&
hmac2), waiting...'
print
('Updated HMAC on registry (config = hmac1
&
hmac2), waiting...')
time.sleep(60)
time.sleep(60)
machine1.screen(machine1.re6st_cmdline)
machine1.screen(machine1.re6st_cmdline)
print
'Started re6st on machine 1, waiting for it to get new conf'
print
('Started re6st on machine 1, waiting for it to get new conf')
time.sleep(60)
time.sleep(60)
print
'Checking HMAC on machines connected to registry 1...'
print
('Checking HMAC on machines connected to registry 1...')
test_hmac.checkHMAC(reg1_db, m_net1)
test_hmac.checkHMAC(reg1_db, m_net1)
print
'Testing of HMAC done!'
print
('Testing of HMAC done!')
# TODO: missing last step
# TODO: missing last step
reg1_db.close()
reg1_db.close()
reg2_db.close()
reg2_db.close()
...
@@ -349,8 +383,9 @@ if args.hmac:
...
@@ -349,8 +383,9 @@ if args.hmac:
t.start()
t.start()
del t
del t
_ll = {}
_ll: dict[str, tuple[Re6stNode, bool]] = {}
def node_by_ll(addr):
def node_by_ll(addr: str) -> tuple[Re6stNode, bool]:
try:
try:
return _ll[addr]
return _ll[addr]
except KeyError:
except KeyError:
...
@@ -368,27 +403,28 @@ def node_by_ll(addr):
...
@@ -368,27 +403,28 @@ def node_by_ll(addr):
if a.startswith('10.42.'):
if a.startswith('10.42.'):
assert not p % 8
assert not p % 8
_ll[socket.inet_ntoa(socket.inet_aton(
_ll[socket.inet_ntoa(socket.inet_aton(
a)[:p/
8].ljust(4,
'\0'))] = n, t
a)[:p/
/8].ljust(4, b
'\0'))] = n, t
elif a.startswith('2001:db8:'):
elif a.startswith('2001:db8:'):
assert not p % 8
assert not p % 8
a = socket.inet_ntop(socket.AF_INET6,
a = socket.inet_ntop(socket.AF_INET6,
socket.inet_pton(socket.AF_INET6,
socket.inet_pton(socket.AF_INET6,
a)[:p
/8].ljust(16,
'\0'))
a)[:p
// 8].ljust(16, b
'\0'))
elif not a.startswith('fe80::'):
elif not a.startswith('fe80::'):
continue
continue
_ll[a] = n, t
_ll[a] = n, t
return _ll[addr]
return _ll[addr]
def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
graph = {}
def route_svg(ipv4, z=4):
graph: dict[Re6stNode, dict[tuple[Re6stNode, bool], list[Re6stNode]]] = {}
for n in nodes:
for n in nodes:
g = graph[n] = defaultdict(list)
g = graph[n] = defaultdict(list)
for r in n.get_routes():
for r in n.get_routes():
if (r.prefix and r.prefix.startswith('10.42.') if ipv4 else
if (r.prefix and r.prefix.startswith('10.42.') if ipv4 else
r.prefix is None or r.prefix.startswith('2001:db8:')):
r.prefix is None or r.prefix.startswith('2001:db8:')):
try:
try:
g[node_by_ll(r.nexthop)].append(
if r.prefix:
node_by_ll(r.prefix)[0] if r.prefix else default
)
g[node_by_ll(r.nexthop)].append(node_by_ll(r.prefix)[0]
)
except KeyError:
except KeyError:
pass
pass
gv = ["digraph { splines = true; edge[color=grey, labelangle=0];"]
gv = ["digraph { splines = true; edge[color=grey, labelangle=0];"]
...
@@ -399,37 +435,37 @@ def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
...
@@ -399,37 +435,37 @@ def route_svg(ipv4, z = 4, default = type('', (), {'short': None})):
gv.append('%s[pos="%s,%s!"];'
gv.append('%s[pos="%s,%s!"];'
% (n.name, z * math.cos(a * i), z * math.sin(a * i)))
% (n.name, z * math.cos(a * i), z * math.sin(a * i)))
l = []
l = []
for p, r in graph[n].ite
rite
ms():
for p, r in graph[n].items():
j = abs(nodes.index(p[0]) - i)
j = abs(nodes.index(p[0]) - i)
l.append((min(j, N - j), p, r))
l.append((min(j, N - j), p, r))
for j, (
l, (p, t), r) in enumerate(sorted(l
)):
for j, (
_, (p2, t), r) in enumerate(sorted(l, key=lambda x: x[0]
)):
l = []
l
2
= []
arrowhead = 'none'
arrowhead = 'none'
for r
in sorted(r.short for r
in r):
for r
2 in sorted(r2.short or '' for r2
in r):
if r:
if r
2
:
if r
== p
.short:
if r
2 == p2
.short:
r
= '
<font
color=
"grey"
>
%s
</font>
' % r
r
2 = '
<font
color=
"grey"
>
%s
</font>
' % r2
l
.append(r
)
l
2.append(r2
)
else:
else:
arrowhead = 'dot'
arrowhead = 'dot'
if (n.name, p.name) in edges:
if (n.name, p
2
.name) in edges:
r = 'penwidth=0'
r
3
= 'penwidth=0'
else:
else:
edges.add((p.name, n.name))
edges.add((p
2
.name, n.name))
r = 'style=solid' if t else 'style=dashed'
r
3
= 'style=solid' if t else 'style=dashed'
gv.append(
gv.append(
'%s -> %s [labeldistance=%u, headlabel=
<
%
s
>
, arrowhead=%s, %s];'
'%s -> %s [labeldistance=%u, headlabel=
<
%
s
>
, arrowhead=%s, %s];'
% (p
.name, n.name, 1.5 * math.sqrt(j) + 2, ','.join(l
),
% (p
2.name, n.name, 1.5 * math.sqrt(j) + 2, ','.join(l2
),
arrowhead, r))
arrowhead, r
3
))
gv.append('}\n')
gv.append('}\n')
return subprocess.
Popen(('neato', '-Tsvg'),
return subprocess.
run(
stdin=subprocess.PIPE, stdout=subprocess.PIPE
,
('neato', '-Tsvg'), check=True, text=True, capture_output=True
,
).communicate('\n'.join(gv))[0]
input='\n'.join(gv)).stdout
if args.port:
if args.port:
import
SimpleHTTPServer, SocketS
erver
import
http.server, sockets
erver
class Handler(
SimpleHTTPS
erver.SimpleHTTPRequestHandler):
class Handler(
http.s
erver.SimpleHTTPRequestHandler):
_path_match = re.compile('/(.+)\.(html|svg)$').match
_path_match = re.compile('/(.+)\.(html|svg)$').match
pages = 'ipv6', 'ipv4', 'tunnels'
pages = 'ipv6', 'ipv4', 'tunnels'
...
@@ -439,7 +475,7 @@ if args.port:
...
@@ -439,7 +475,7 @@ if args.port:
try:
try:
name, ext = self._path_match(self.path).groups()
name, ext = self._path_match(self.path).groups()
page = self.pages.index(name)
page = self.pages.index(name)
except
AttributeError, ValueError
:
except
(AttributeError, ValueError)
:
if self.path == '/':
if self.path == '/':
self.send_response(302)
self.send_response(302)
self.send_header('Location', self.pages[0] + '.html')
self.send_header('Location', self.pages[0] + '.html')
...
@@ -450,36 +486,46 @@ if args.port:
...
@@ -450,36 +486,46 @@ if args.port:
if page
<
2:
if page
<
2:
body =
route_svg(page)
body =
route_svg(page)
else:
else:
body =
registry.Popen(('python
',
'
-c
',
r
"""
if
1:
out
,
err =
(registry.Popen(('python3
',
'
-c
',
r
"""
if
1:
import
math
,
json
import
math
,
json
from
re6st
.
registry
import
RegistryClient
from
re6st
.
registry
import
RegistryClient
g =
json.loads(RegistryClient(
topo =
RegistryClient('http://localhost/').topology()
'
http:
//
localhost
/').
topology
())
g =
json.loads(topo)
if
not
g:
print
('
digraph
{
"
empty
topology
"
[shape=
"none"
]
}')
exit
()
r =
set(g.pop('',
()))
r =
set(g.pop('',
()))
a =
set()
a =
set()
for
v
in
g
.
iter
values
()
:
for
v
in
g
.
values
()
:
a
.
update
(
v
)
a
.
update
(
v
)
g
.
update
(
dict
.
fromkeys
(
a
.
difference
(
g
),
()))
g
.
update
(
dict
.
fromkeys
(
a
.
difference
(
g
),
()))
print
'
digraph
{'
print
('
digraph
{')
a =
2
*
math
.
pi
/
len
(
g
)
a =
2
*
math
.
pi
/
len
(
g
)
z =
4
z =
4
m2 =
'%u/80'
%
(
2
<<
64
)
m2 =
'%u/80'
%
(
2
<<
64
)
title =
lambda
n:
'
2
|
80
'
if
n =
=
m2
else
n
title =
lambda
n:
'
2
|
80
'
if
n =
=
m2
else
n
g =
sorted((title(k),
k
in
r
,
v
)
for
k
,
v
in
g
.
ite
rite
ms
())
g =
sorted((title(k),
k
in
r
,
v
)
for
k
,
v
in
g
.
items
())
for
i
,
(
n
,
r
,
v
)
in
enumerate
(
g
)
:
for
i
,
(
n
,
r
,
v
)
in
enumerate
(
g
)
:
print
'"%
s
"
[pos=
"%s,%s!"
%
s
];'
%
(
title
(
n
),
print
(
'"%
s
"
[pos=
"%s,%s!"
%
s
];'
%
(
title
(
n
),
z
*
math
.
cos
(
a
*
i
),
z
*
math
.
sin
(
a
*
i
),
z
*
math
.
cos
(
a
*
i
),
z
*
math
.
sin
(
a
*
i
),
''
if
r
else
',
style=
dashed')
''
if
r
else
',
style=
dashed')
)
for
v
in
v:
for
v
in
v:
print
'"%
s
"
-
>
"%s";' % (n, title(v))
print
('"%
s
"
-
>
"%s";' % (n, title(v)))
print '}'
print('}')
"""), stdout=subprocess.PIPE, cwd="..").communicate()[0]
"""), stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd="..")
if body:
.communicate())
body = subprocess.Popen(('neato', '-Tsvg'),
if err:
stdin=subprocess.PIPE, stdout=subprocess.PIPE,
self.send_error(500, explain='SVG generation failed: '
).communicate(body)[0]
+ err.decode(errors='replace'))
if not body:
return
self.send_error(500)
graph_body = out.decode("utf-8")
try:
body = subprocess.run(
('neato', '-Tsvg'), check=True, text=True,
capture_output=True,
input=graph_body).stdout
except subprocess.CalledProcessError as e:
self.send_error(500, explain='neato failed: ' + e.stderr)
return
return
if ext == 'svg':
if ext == 'svg':
mt = 'image/svg+xml'
mt = 'image/svg+xml'
...
@@ -508,14 +554,15 @@ if args.port:
...
@@ -508,14 +554,15 @@ if args.port:
for i, x in enumerate(self.pages)),
for i, x in enumerate(self.pages)),
body[body.find('
<svg
')
:
])
body[body.find('
<svg
')
:
])
self
.
send_response
(
200
)
self
.
send_response
(
200
)
self
.
send_header
('
Content-Length
',
len
(
body
))
body =
body.encode("utf-8")
self
.
send_header
('
Content-Length
',
str
(
len
(
body
)))
self
.
send_header
('
Content-type
',
mt
+
';
charset=
utf-8')
self
.
send_header
('
Content-type
',
mt
+
';
charset=
utf-8')
self
.
end_headers
()
self
.
end_headers
()
self
.
wfile
.
write
(
body
)
self
.
wfile
.
write
(
body
)
class
TCPServer
(
SocketS
erver
.
TCPServer
)
:
class
TCPServer
(
sockets
erver
.
TCPServer
)
:
allow_reuse_address =
True
allow_reuse_address =
True
TCPServer
(('',
args
.
port
),
Handler
).
serve_forever
()
TCPServer
(('',
args
.
port
),
Handler
).
serve_forever
()
import
pdb
;
pdb
.
set_trace
()
breakpoint
()
demo/fixnemu.py
deleted
100644 → 0
View file @
f2fd7247
# -*- coding: utf-8 -*-
# Copyright 2010, 2011 INRIA
# Copyright 2011 Martín Ferrari <martin.ferrari@gmail.com>
#
# This file is contains patches to Nemu.
#
# Nemu is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License version 2, as published by the Free
# Software Foundation.
#
# Nemu is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# Nemu. If not, see <http://www.gnu.org/licenses/>.
import
re
import
os
from
new
import
function
from
nemu.iproute
import
backticks
,
get_if_data
,
route
,
\
get_addr_data
,
get_all_route_data
,
interface
from
nemu.interface
import
Switch
,
Interface
def
_get_all_route_data
():
ipdata
=
backticks
([
IP_PATH
,
"-o"
,
"route"
,
"list"
])
# "table", "all"
ipdata
+=
backticks
([
IP_PATH
,
"-o"
,
"-f"
,
"inet6"
,
"route"
,
"list"
])
ifdata
=
get_if_data
()[
1
]
ret
=
[]
for
line
in
ipdata
.
split
(
"
\
n
"
):
if
line
==
""
:
continue
# PATCH: parse 'from'
# PATCH: 'dev' is missing on 'unreachable' ipv4 routes
match
=
re
.
match
(
'(?:(unicast|local|broadcast|multicast|throw|'
r'unreachable|prohibit|blackhole|nat) )?(\
S+)(?:
from (\
S+))?
'
r'
(
?
:
via
(
\
S
+
))
?
(
?
:
dev
(
\
S
+
))
?
.
*
(
?
:
metric
(
\
d
+
))
?
', line)
if not match:
raise RuntimeError("Invalid output from `ip route'
:
`
%
s
'" % line)
tipe = match.group(1) or "unicast"
prefix = match.group(2)
#src = match.group(3)
nexthop = match.group(4)
interface = ifdata[match.group(5) or "lo"]
metric = match.group(6)
if prefix == "default" or re.search(r'
/
0
$
', prefix):
prefix = None
prefix_len = 0
else:
match = re.match(r'
([
0
-
9
a
-
f
:.]
+
)(
?
:
/
(
\
d
+
))
?$
', prefix)
prefix = match.group(1)
prefix_len = int(match.group(2) or 32)
ret.append(route(tipe, prefix, prefix_len, nexthop, interface.index,
metric))
return ret
get_all_route_data.func_code = _get_all_route_data.func_code
interface__init__ = interface.__init__
def __init__(self, *args, **kw):
interface__init__(self, *args, **kw)
if self.name:
self.name = self.name.split('
@
',1)[0]
interface.__init__ = __init__
get_addr_data.orig = function(get_addr_data.func_code,
get_addr_data.func_globals)
def _get_addr_data():
byidx, bynam = get_addr_data.orig()
return byidx, {name.split('
@
',1)[0]: a for name, a in bynam.iteritems()}
get_addr_data.func_code = _get_addr_data.func_code
@staticmethod
def _gen_if_name():
n = Interface._gen_next_id()
# Max 15 chars
# XXX: We truncate pid to not exceed IFNAMSIZ on systems with 32-bits pids
# but we should find something better to avoid possible collision.
return "NETNSif-%.4x%.3x" % (os.getpid() % 0xffff, n)
Interface._gen_if_name = _gen_if_name
@staticmethod
def _gen_br_name():
n = Switch._gen_next_id()
# XXX: same as for _gen_if_name
return "NETNSbr-%.4x%.3x" % (os.getpid() % 0xffff, n)
Switch._gen_br_name = _gen_br_name
demo/py
View file @
e80ca878
#!/usr/bin/env python
#!/usr/bin/env python
3
def
__file__
():
def
__file__
():
import
argparse
,
os
,
sys
import
argparse
,
os
,
sys
sys
.
dont_write_bytecode
=
True
sys
.
dont_write_bytecode
=
True
...
@@ -30,4 +30,5 @@ def __file__():
...
@@ -30,4 +30,5 @@ def __file__():
return
os
.
path
.
join
(
sys
.
path
[
0
],
sys
.
argv
[
1
])
return
os
.
path
.
join
(
sys
.
path
[
0
],
sys
.
argv
[
1
])
__file__
=
__file__
()
__file__
=
__file__
()
execfile
(
__file__
)
with
open
(
__file__
)
as
f
:
exec
(
compile
(
f
.
read
(),
__file__
,
'exec'
))
demo/test_hmac.py
View file @
e80ca878
...
@@ -34,7 +34,7 @@ def checkHMAC(db, machines):
...
@@ -34,7 +34,7 @@ def checkHMAC(db, machines):
else
:
else
:
i
=
0
if
hmac
[
0
]
else
1
i
=
0
if
hmac
[
0
]
else
1
if
hmac
[
i
]
!=
sign
or
hmac
[
i
+
1
]
!=
accept
:
if
hmac
[
i
]
!=
sign
or
hmac
[
i
+
1
]
!=
accept
:
print
'HMAC config wrong for in %s'
%
args
print
(
'HMAC config wrong for in %s'
%
args
)
rc
=
False
rc
=
False
if
rc
:
if
rc
:
print
(
'All nodes use Babel with the correct HMAC configuration'
)
print
(
'All nodes use Babel with the correct HMAC configuration'
)
...
...
draft/re6st-cn
View file @
e80ca878
...
@@ -5,7 +5,7 @@ if 're6st' not in sys.modules:
...
@@ -5,7 +5,7 @@ if 're6st' not in sys.modules:
from
re6st
import
utils
,
x509
from
re6st
import
utils
,
x509
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
with
open
(
"/etc/re6stnet/ca.crt"
)
as
f
:
with
open
(
"/etc/re6stnet/ca.crt"
,
"rb"
)
as
f
:
ca
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
f
.
read
())
ca
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
f
.
read
())
network
=
x509
.
networkFromCa
(
ca
)
network
=
x509
.
networkFromCa
(
ca
)
...
...
re6st/cache.py
View file @
e80ca878
import
json
,
logging
,
os
,
sqlite3
,
socket
,
subprocess
,
sys
,
time
,
zlib
import
base64
,
json
,
logging
,
os
,
sqlite3
,
socket
,
subprocess
,
sys
,
time
,
zlib
from
itertools
import
chain
from
itertools
import
chain
from
.registry
import
RegistryClient
from
.registry
import
RegistryClient
from
.
import
utils
,
version
,
x509
from
.
import
utils
,
version
,
x509
class
Cache
(
object
)
:
class
Cache
:
def
__init__
(
self
,
db_path
,
registry
,
c
ert
,
db_size
=
200
):
def
__init__
(
self
,
db_path
:
str
,
registry
,
cert
:
x509
.
C
ert
,
db_size
=
200
):
self
.
_prefix
=
cert
.
prefix
self
.
_prefix
=
cert
.
prefix
self
.
_db_size
=
db_size
self
.
_db_size
=
db_size
self
.
_decrypt
=
cert
.
decrypt
self
.
_decrypt
=
cert
.
decrypt
...
@@ -50,7 +50,7 @@ class Cache(object):
...
@@ -50,7 +50,7 @@ class Cache(object):
self
.
warnProtocol
()
self
.
warnProtocol
()
logging
.
info
(
"Cache initialized."
)
logging
.
info
(
"Cache initialized."
)
def
_open
(
self
,
path
)
:
def
_open
(
self
,
path
:
str
)
->
sqlite3
.
Connection
:
db
=
sqlite3
.
connect
(
path
,
isolation_level
=
None
)
db
=
sqlite3
.
connect
(
path
,
isolation_level
=
None
)
db
.
text_factory
=
str
db
.
text_factory
=
str
db
.
execute
(
"PRAGMA synchronous = OFF"
)
db
.
execute
(
"PRAGMA synchronous = OFF"
)
...
@@ -64,9 +64,8 @@ class Cache(object):
...
@@ -64,9 +64,8 @@ class Cache(object):
return
db
return
db
@
staticmethod
@
staticmethod
def
_selectConfig
(
execute
):
# BBB: blob
def
_selectConfig
(
execute
):
return
((
k
,
str
(
v
)
if
type
(
v
)
is
buffer
else
v
)
return
execute
(
"SELECT * FROM config"
)
for
k
,
v
in
execute
(
"SELECT * FROM config"
))
def
_loadConfig
(
self
,
config
):
def
_loadConfig
(
self
,
config
):
cls
=
self
.
__class__
cls
=
self
.
__class__
...
@@ -89,24 +88,23 @@ class Cache(object):
...
@@ -89,24 +88,23 @@ class Cache(object):
logging
.
info
(
"Getting new network parameters from registry..."
)
logging
.
info
(
"Getting new network parameters from registry..."
)
try
:
try
:
# TODO: When possible, the registry should be queried via the re6st.
# TODO: When possible, the registry should be queried via the re6st.
x
=
json
.
loads
(
zlib
.
decompress
(
network_config
=
self
.
_registry
.
getNetworkConfig
(
self
.
_prefix
)
self
.
_registry
.
getNetworkConfig
(
self
.
_prefix
)))
logging
.
debug
(
'getNetworkConfig result: %r'
,
network_config
)
base64
=
x
.
pop
(
''
,
())
x
=
json
.
loads
(
zlib
.
decompress
(
network_config
))
base64_list
=
x
.
pop
(
''
,
())
config
=
{}
config
=
{}
for
k
,
v
in
x
.
ite
rite
ms
():
for
k
,
v
in
x
.
items
():
k
=
str
(
k
)
k
=
str
(
k
)
if
k
.
startswith
(
'babel_hmac'
):
if
k
.
startswith
(
'babel_hmac'
):
if
v
:
if
v
:
v
=
self
.
_decrypt
(
v
.
decode
(
'base64'
))
v
=
self
.
_decrypt
(
base64
.
b64decode
(
v
))
elif
k
in
base64
:
elif
k
in
base64_list
:
v
=
v
.
decode
(
'base64'
)
v
=
base64
.
b64decode
(
v
)
elif
type
(
v
)
is
unicode
:
v
=
str
(
v
)
elif
isinstance
(
v
,
(
list
,
dict
)):
elif
isinstance
(
v
,
(
list
,
dict
)):
k
+=
':json'
k
+=
':json'
v
=
json
.
dumps
(
v
)
v
=
json
.
dumps
(
v
)
config
[
k
]
=
v
config
[
k
]
=
v
except
socket
.
error
,
e
:
except
socket
.
error
as
e
:
logging
.
warning
(
e
)
logging
.
warning
(
e
)
return
return
except
Exception
:
except
Exception
:
...
@@ -130,16 +128,12 @@ class Cache(object):
...
@@ -130,16 +128,12 @@ class Cache(object):
remove
.
append
(
k
)
remove
.
append
(
k
)
db
.
execute
(
"DELETE FROM config WHERE name in ('%s')"
db
.
execute
(
"DELETE FROM config WHERE name in ('%s')"
%
"','"
.
join
(
remove
))
%
"','"
.
join
(
remove
))
# BBB: Use buffer because of http://bugs.python.org/issue13676
# on Python 2.6
db
.
executemany
(
"INSERT OR REPLACE INTO config VALUES(?,?)"
,
db
.
executemany
(
"INSERT OR REPLACE INTO config VALUES(?,?)"
,
((
k
,
buffer
(
v
)
if
k
in
base64
or
config
.
items
())
k
.
startswith
(
'babel_hmac'
)
else
v
)
self
.
_loadConfig
(
config
.
items
())
for
k
,
v
in
config
.
iteritems
()))
self
.
_loadConfig
(
config
.
iteritems
())
return
[
k
[:
-
5
]
if
k
.
endswith
(
':json'
)
else
k
return
[
k
[:
-
5
]
if
k
.
endswith
(
':json'
)
else
k
for
k
in
chain
(
remove
,
(
k
for
k
in
chain
(
remove
,
(
k
for
k
,
v
in
config
.
ite
rite
ms
()
for
k
,
v
in
config
.
items
()
if
k
not
in
old
or
old
[
k
]
!=
v
))]
if
k
not
in
old
or
old
[
k
]
!=
v
))]
def
warnProtocol
(
self
):
def
warnProtocol
(
self
):
...
@@ -147,7 +141,7 @@ class Cache(object):
...
@@ -147,7 +141,7 @@ class Cache(object):
logging
.
warning
(
"There's a new version of re6stnet:"
logging
.
warning
(
"There's a new version of re6stnet:"
" you should update."
)
" you should update."
)
def
getDh
(
self
,
path
):
def
getDh
(
self
,
path
:
str
):
# We'd like to do a full check here but
# We'd like to do a full check here but
# from OpenSSL import SSL
# from OpenSSL import SSL
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
# SSL.Context(SSL.TLSv1_METHOD).load_tmp_dh(path)
...
@@ -179,11 +173,11 @@ class Cache(object):
...
@@ -179,11 +173,11 @@ class Cache(object):
logging
.
trace
(
"- %s: %s%s"
,
prefix
,
address
,
logging
.
trace
(
"- %s: %s%s"
,
prefix
,
address
,
' (blacklisted)'
if
_try
else
''
)
' (blacklisted)'
if
_try
else
''
)
def
cacheMinimize
(
self
,
size
):
def
cacheMinimize
(
self
,
size
:
int
):
with
self
.
_db
:
with
self
.
_db
:
self
.
_cacheMinimize
(
size
)
self
.
_cacheMinimize
(
size
)
def
_cacheMinimize
(
self
,
size
):
def
_cacheMinimize
(
self
,
size
:
int
):
a
=
self
.
_db
.
execute
(
a
=
self
.
_db
.
execute
(
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1"
,
"SELECT peer FROM volatile.stat ORDER BY try, RANDOM() LIMIT ?,-1"
,
(
size
,)).
fetchall
()
(
size
,)).
fetchall
()
...
@@ -192,26 +186,26 @@ class Cache(object):
...
@@ -192,26 +186,26 @@ class Cache(object):
q
(
"DELETE FROM peer WHERE prefix IN (?)"
,
a
)
q
(
"DELETE FROM peer WHERE prefix IN (?)"
,
a
)
q
(
"DELETE FROM volatile.stat WHERE peer IN (?)"
,
a
)
q
(
"DELETE FROM volatile.stat WHERE peer IN (?)"
,
a
)
def
connecting
(
self
,
prefix
,
connecting
):
def
connecting
(
self
,
prefix
:
str
,
connecting
:
bool
):
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=? WHERE peer=?"
,
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=? WHERE peer=?"
,
(
connecting
,
prefix
))
(
connecting
,
prefix
))
def
resetConnecting
(
self
):
def
resetConnecting
(
self
):
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=0"
)
self
.
_db
.
execute
(
"UPDATE volatile.stat SET try=0"
)
def
getAddress
(
self
,
prefix
)
:
def
getAddress
(
self
,
prefix
:
str
)
->
bool
:
r
=
self
.
_db
.
execute
(
"SELECT address FROM peer, volatile.stat"
r
=
self
.
_db
.
execute
(
"SELECT address FROM peer, volatile.stat"
" WHERE prefix=? AND prefix=peer AND try=0"
,
" WHERE prefix=? AND prefix=peer AND try=0"
,
(
prefix
,)).
fetchone
()
(
prefix
,)).
fetchone
()
return
r
and
r
[
0
]
return
r
and
r
[
0
]
@
property
@
property
def
my_address
(
self
):
def
my_address
(
self
)
->
str
:
for
x
,
in
self
.
_db
.
execute
(
"SELECT address FROM peer WHERE prefix=''"
):
for
x
,
in
self
.
_db
.
execute
(
"SELECT address FROM peer WHERE prefix=''"
):
return
x
return
x
@
my_address
.
setter
@
my_address
.
setter
def
my_address
(
self
,
value
):
def
my_address
(
self
,
value
:
str
):
if
value
:
if
value
:
with
self
.
_db
as
db
:
with
self
.
_db
as
db
:
db
.
execute
(
"INSERT OR REPLACE INTO peer VALUES ('', ?)"
,
db
.
execute
(
"INSERT OR REPLACE INTO peer VALUES ('', ?)"
,
...
@@ -229,18 +223,20 @@ class Cache(object):
...
@@ -229,18 +223,20 @@ class Cache(object):
# IOW, one should probably always put our own address there.
# IOW, one should probably always put our own address there.
_get_peer_sql
=
"SELECT %s FROM peer, volatile.stat"
\
_get_peer_sql
=
"SELECT %s FROM peer, volatile.stat"
\
" WHERE prefix=peer AND prefix!=? AND try=?"
" WHERE prefix=peer AND prefix!=? AND try=?"
def
getPeerList
(
self
,
failed
=
0
,
__sql
=
_get_peer_sql
%
"prefix, address"
def
getPeerList
(
self
,
failed
=
False
,
__sql
=
_get_peer_sql
%
"prefix, address"
+
" ORDER BY RANDOM()"
):
+
" ORDER BY RANDOM()"
):
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
))
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
))
def
getPeerCount
(
self
,
failed
=
0
,
__sql
=
_get_peer_sql
%
"COUNT(*)"
):
def
getPeerCount
(
self
,
failed
=
False
,
__sql
=
_get_peer_sql
%
"COUNT(*)"
)
\
->
int
:
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
)).
next
()[
0
]
return
self
.
_db
.
execute
(
__sql
,
(
self
.
_prefix
,
failed
)).
next
()[
0
]
def
getBootstrapPeer
(
self
):
def
getBootstrapPeer
(
self
)
->
tuple
[
str
,
str
]
:
logging
.
info
(
'Getting Boot peer...'
)
logging
.
info
(
'Getting Boot peer...'
)
try
:
try
:
bootpeer
=
self
.
_registry
.
getBootstrapPeer
(
self
.
_prefix
)
bootpeer
=
self
.
_registry
.
getBootstrapPeer
(
self
.
_prefix
)
prefix
,
address
=
self
.
_decrypt
(
bootpeer
).
split
()
prefix
,
address
=
self
.
_decrypt
(
bootpeer
).
decode
().
split
()
except
(
socket
.
error
,
subprocess
.
CalledProcessError
,
ValueError
)
,
e
:
except
(
socket
.
error
,
subprocess
.
CalledProcessError
,
ValueError
)
as
e
:
logging
.
warning
(
'Failed to bootstrap (%s)'
,
logging
.
warning
(
'Failed to bootstrap (%s)'
,
e
if
bootpeer
else
'no peer returned'
)
e
if
bootpeer
else
'no peer returned'
)
else
:
else
:
...
@@ -249,7 +245,7 @@ class Cache(object):
...
@@ -249,7 +245,7 @@ class Cache(object):
return
prefix
,
address
return
prefix
,
address
logging
.
warning
(
'Buggy registry sent us our own address'
)
logging
.
warning
(
'Buggy registry sent us our own address'
)
def
addPeer
(
self
,
prefix
,
address
,
set_preferred
=
False
):
def
addPeer
(
self
,
prefix
:
str
,
address
:
str
,
set_preferred
=
False
):
logging
.
debug
(
'Adding peer %s: %s'
,
prefix
,
address
)
logging
.
debug
(
'Adding peer %s: %s'
,
prefix
,
address
)
with
self
.
_db
:
with
self
.
_db
:
q
=
self
.
_db
.
execute
q
=
self
.
_db
.
execute
...
@@ -273,8 +269,8 @@ class Cache(object):
...
@@ -273,8 +269,8 @@ class Cache(object):
q
(
"INSERT OR REPLACE INTO peer VALUES (?,?)"
,
(
prefix
,
address
))
q
(
"INSERT OR REPLACE INTO peer VALUES (?,?)"
,
(
prefix
,
address
))
q
(
"INSERT OR REPLACE INTO volatile.stat VALUES (?,0)"
,
(
prefix
,))
q
(
"INSERT OR REPLACE INTO volatile.stat VALUES (?,0)"
,
(
prefix
,))
def
getCountry
(
self
,
ip
)
:
def
getCountry
(
self
,
ip
:
str
)
->
str
:
try
:
try
:
return
self
.
_registry
.
getCountry
(
self
.
_prefix
,
ip
)
return
self
.
_registry
.
getCountry
(
self
.
_prefix
,
ip
)
.
decode
()
except
socket
.
error
,
e
:
except
socket
.
error
as
e
:
logging
.
warning
(
'Failed to get country (%s)'
,
ip
)
logging
.
warning
(
'Failed to get country (%s)'
,
ip
)
re6st/cli/conf.py
View file @
e80ca878
#!/usr/bin/
python2
#!/usr/bin/
env python3
import
argparse
,
atexit
,
binascii
,
errno
,
hashlib
import
argparse
,
atexit
,
binascii
,
errno
,
hashlib
import
os
,
subprocess
,
sqlite3
,
sys
,
time
import
os
,
subprocess
,
sqlite3
,
sys
,
time
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
...
@@ -6,14 +6,14 @@ if 're6st' not in sys.modules:
...
@@ -6,14 +6,14 @@ if 're6st' not in sys.modules:
sys
.
path
[
0
]
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
sys
.
path
[
0
]))
sys
.
path
[
0
]
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
sys
.
path
[
0
]))
from
re6st
import
registry
,
utils
,
x509
from
re6st
import
registry
,
utils
,
x509
def
create
(
path
,
text
=
None
,
mode
=
0666
):
def
create
(
path
,
text
=
None
,
mode
=
0
o
666
):
fd
=
os
.
open
(
path
,
os
.
O_CREAT
|
os
.
O_WRONLY
|
os
.
O_TRUNC
,
mode
)
fd
=
os
.
open
(
path
,
os
.
O_CREAT
|
os
.
O_WRONLY
|
os
.
O_TRUNC
,
mode
)
try
:
try
:
os
.
write
(
fd
,
text
)
os
.
write
(
fd
,
text
)
finally
:
finally
:
os
.
close
(
fd
)
os
.
close
(
fd
)
def
loadCert
(
pem
):
def
loadCert
(
pem
:
bytes
):
return
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
pem
)
return
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
pem
)
def
main
():
def
main
():
...
@@ -68,17 +68,18 @@ def main():
...
@@ -68,17 +68,18 @@ def main():
fingerprint
=
binascii
.
a2b_hex
(
fingerprint
)
fingerprint
=
binascii
.
a2b_hex
(
fingerprint
)
if
hashlib
.
new
(
alg
).
digest_size
!=
len
(
fingerprint
):
if
hashlib
.
new
(
alg
).
digest_size
!=
len
(
fingerprint
):
raise
ValueError
(
"wrong size"
)
raise
ValueError
(
"wrong size"
)
except
StandardError
,
e
:
except
Exception
as
e
:
parser
.
error
(
"invalid fingerprint: %s"
%
e
)
parser
.
error
(
"invalid fingerprint: %s"
%
e
)
if
x509
.
fingerprint
(
ca
,
alg
).
digest
()
!=
fingerprint
:
if
x509
.
fingerprint
(
ca
,
alg
).
digest
()
!=
fingerprint
:
sys
.
exit
(
"CA fingerprint doesn't match"
)
sys
.
exit
(
"CA fingerprint doesn't match"
)
else
:
else
:
print
"WARNING: it is strongly recommended to use --fingerprint option."
print
(
"WARNING: it is strongly recommended to use --fingerprint option."
)
network
=
x509
.
networkFromCa
(
ca
)
network
=
x509
.
networkFromCa
(
ca
)
if
config
.
is_needed
:
if
config
.
is_needed
:
route
,
err
=
subprocess
.
Popen
((
'ip'
,
'-6'
,
'-o'
,
'route'
,
'get'
,
with
subprocess
.
Popen
((
'ip'
,
'-6'
,
'-o'
,
'route'
,
'get'
,
utils
.
ipFromBin
(
network
)),
utils
.
ipFromBin
(
network
)),
stdout
=
subprocess
.
PIPE
).
communicate
()
stdout
=
subprocess
.
PIPE
)
as
proc
:
route
,
err
=
proc
.
communicate
()
sys
.
exit
(
err
or
route
and
sys
.
exit
(
err
or
route
and
utils
.
binFromIp
(
route
.
split
()[
8
]).
startswith
(
network
))
utils
.
binFromIp
(
route
.
split
()[
8
]).
startswith
(
network
))
...
@@ -89,19 +90,20 @@ def main():
...
@@ -89,19 +90,20 @@ def main():
reserved
=
'CN'
,
'serial'
reserved
=
'CN'
,
'serial'
req
=
crypto
.
X509Req
()
req
=
crypto
.
X509Req
()
try
:
try
:
with
open
(
cert_path
)
as
f
:
with
open
(
cert_path
,
"rb"
)
as
f
:
cert
=
loadCert
(
f
.
read
())
cert
=
loadCert
(
f
.
read
())
components
=
dict
(
cert
.
get_subject
().
get_components
())
components
=
\
{
k
.
decode
():
v
for
k
,
v
in
cert
.
get_subject
().
get_components
()}
for
k
in
reserved
:
for
k
in
reserved
:
components
.
pop
(
k
,
None
)
components
.
pop
(
k
,
None
)
except
IOError
,
e
:
except
IOError
as
e
:
if
e
.
errno
!=
errno
.
ENOENT
:
if
e
.
errno
!=
errno
.
ENOENT
:
raise
raise
components
=
{}
components
=
{}
if
config
.
req
:
if
config
.
req
:
components
.
update
(
config
.
req
)
components
.
update
(
config
.
req
)
subj
=
req
.
get_subject
()
subj
=
req
.
get_subject
()
for
k
,
v
in
components
.
ite
rite
ms
():
for
k
,
v
in
components
.
items
():
if
k
in
reserved
:
if
k
in
reserved
:
sys
.
exit
(
k
+
" field is reserved."
)
sys
.
exit
(
k
+
" field is reserved."
)
if
v
:
if
v
:
...
@@ -116,35 +118,35 @@ def main():
...
@@ -116,35 +118,35 @@ def main():
token
=
''
token
=
''
elif
not
token
:
elif
not
token
:
if
not
config
.
email
:
if
not
config
.
email
:
config
.
email
=
raw_
input
(
'Please enter your email address: '
)
config
.
email
=
input
(
'Please enter your email address: '
)
s
.
requestToken
(
config
.
email
)
s
.
requestToken
(
config
.
email
)
token_advice
=
"Use --token to retry without asking a new token
\
n
"
token_advice
=
"Use --token to retry without asking a new token
\
n
"
while
not
token
:
while
not
token
:
token
=
raw_
input
(
'Please enter your token: '
)
token
=
input
(
'Please enter your token: '
)
try
:
try
:
with
open
(
key_path
)
as
f
:
with
open
(
key_path
)
as
f
:
pkey
=
crypto
.
load_privatekey
(
crypto
.
FILETYPE_PEM
,
f
.
read
())
pkey
=
crypto
.
load_privatekey
(
crypto
.
FILETYPE_PEM
,
f
.
read
())
key
=
None
key
=
None
print
"Reusing existing key."
print
(
"Reusing existing key."
)
except
IOError
,
e
:
except
IOError
as
e
:
if
e
.
errno
!=
errno
.
ENOENT
:
if
e
.
errno
!=
errno
.
ENOENT
:
raise
raise
bits
=
ca
.
get_pubkey
().
bits
()
bits
=
ca
.
get_pubkey
().
bits
()
print
"Generating %s-bit key ..."
%
bits
print
(
"Generating %s-bit key ..."
%
bits
)
pkey
=
crypto
.
PKey
()
pkey
=
crypto
.
PKey
()
pkey
.
generate_key
(
crypto
.
TYPE_RSA
,
bits
)
pkey
.
generate_key
(
crypto
.
TYPE_RSA
,
bits
)
key
=
crypto
.
dump_privatekey
(
crypto
.
FILETYPE_PEM
,
pkey
)
key
=
crypto
.
dump_privatekey
(
crypto
.
FILETYPE_PEM
,
pkey
)
create
(
key_path
,
key
,
0600
)
create
(
key_path
,
key
,
0
o
600
)
req
.
set_pubkey
(
pkey
)
req
.
set_pubkey
(
pkey
)
req
.
sign
(
pkey
,
'sha512'
)
req
.
sign
(
pkey
,
'sha512'
)
req
=
crypto
.
dump_certificate_request
(
crypto
.
FILETYPE_PEM
,
req
)
req
=
crypto
.
dump_certificate_request
(
crypto
.
FILETYPE_PEM
,
req
)
.
decode
()
# First make sure we can open certificate file for writing,
# First make sure we can open certificate file for writing,
# to avoid using our token for nothing.
# to avoid using our token for nothing.
cert_fd
=
os
.
open
(
cert_path
,
os
.
O_CREAT
|
os
.
O_WRONLY
,
0666
)
cert_fd
=
os
.
open
(
cert_path
,
os
.
O_CREAT
|
os
.
O_WRONLY
,
0
o
666
)
print
"Requesting certificate ..."
print
(
"Requesting certificate ..."
)
if
config
.
location
:
if
config
.
location
:
cert
=
s
.
requestCertificate
(
token
,
req
,
location
=
config
.
location
)
cert
=
s
.
requestCertificate
(
token
,
req
,
location
=
config
.
location
)
else
:
else
:
...
@@ -173,7 +175,7 @@ def main():
...
@@ -173,7 +175,7 @@ def main():
key_path
))
key_path
))
if
not
os
.
path
.
lexists
(
conf_path
):
if
not
os
.
path
.
lexists
(
conf_path
):
create
(
conf_path
,
"""
\
create
(
conf_path
,
(
"""
\
registry %s
registry %s
ca %s
ca %s
cert %s
cert %s
...
@@ -187,14 +189,14 @@ key %s
...
@@ -187,14 +189,14 @@ key %s
#O--verb
#O--verb
#O3
#O3
"""
%
(
config
.
registry
,
ca_path
,
cert_path
,
key_path
,
"""
%
(
config
.
registry
,
ca_path
,
cert_path
,
key_path
,
(
'country '
+
config
.
location
.
split
(
','
,
1
)[
0
])
\
(
'country '
+
config
.
location
.
split
(
','
,
1
)[
0
])
if
config
.
location
else
''
))
if
config
.
location
else
''
))
.
encode
())
print
"Sample configuration file created."
print
(
"Sample configuration file created."
)
cn
=
x509
.
subnetFromCert
(
cert
)
cn
=
x509
.
subnetFromCert
(
cert
)
subnet
=
network
+
utils
.
binFromSubnet
(
cn
)
subnet
=
network
+
utils
.
binFromSubnet
(
cn
)
print
"Your subnet: %s/%u (CN=%s)"
\
print
(
"Your subnet: %s/%u (CN=%s)"
%
(
utils
.
ipFromBin
(
subnet
),
len
(
subnet
),
cn
)
%
(
utils
.
ipFromBin
(
subnet
),
len
(
subnet
),
cn
)
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
main
()
main
()
re6st/cli/node.py
View file @
e80ca878
#!/usr/bin/
python2
#!/usr/bin/
env python3
import
atexit
,
errno
,
logging
,
os
,
shutil
,
signal
import
atexit
,
errno
,
logging
,
os
,
shutil
,
signal
import
socket
,
struct
,
subprocess
,
sys
import
socket
,
struct
,
subprocess
,
sys
from
collections
import
deque
from
collections
import
deque
...
@@ -246,7 +246,7 @@ def main():
...
@@ -246,7 +246,7 @@ def main():
try
:
try
:
from
re6st.upnpigd
import
Forwarder
from
re6st.upnpigd
import
Forwarder
forwarder
=
Forwarder
(
're6stnet openvpn server'
)
forwarder
=
Forwarder
(
're6stnet openvpn server'
)
except
Exception
,
e
:
except
Exception
as
e
:
if
ipv4
:
if
ipv4
:
raise
raise
logging
.
info
(
"%s: assume we are not NATed"
,
e
)
logging
.
info
(
"%s: assume we are not NATed"
,
e
)
...
@@ -266,19 +266,13 @@ def main():
...
@@ -266,19 +266,13 @@ def main():
def
call
(
cmd
):
def
call
(
cmd
):
logging
.
debug
(
'%r'
,
cmd
)
logging
.
debug
(
'%r'
,
cmd
)
p
=
subprocess
.
Popen
(
cmd
,
stdout
=
subprocess
.
PIPE
,
return
subprocess
.
run
(
cmd
,
capture_output
=
True
,
check
=
True
).
stdout
stderr
=
subprocess
.
PIPE
)
def
ip4
(
object
:
str
,
*
args
):
stdout
,
stderr
=
p
.
communicate
()
if
p
.
returncode
:
raise
EnvironmentError
(
"%r failed with error %u
\
n
%s"
%
(
' '
.
join
(
cmd
),
p
.
returncode
,
stderr
))
return
stdout
def
ip4
(
object
,
*
args
):
args
=
[
'ip'
,
'-4'
,
object
,
'add'
]
+
list
(
args
)
args
=
[
'ip'
,
'-4'
,
object
,
'add'
]
+
list
(
args
)
call
(
args
)
call
(
args
)
args
[
3
]
=
'del'
args
[
3
]
=
'del'
cleanup
.
append
(
lambda
:
subprocess
.
call
(
args
))
cleanup
.
append
(
lambda
:
subprocess
.
call
(
args
))
def
ip
(
object
,
*
args
):
def
ip
(
object
:
str
,
*
args
):
args
=
[
'ip'
,
'-6'
,
object
,
'add'
]
+
list
(
args
)
args
=
[
'ip'
,
'-6'
,
object
,
'add'
]
+
list
(
args
)
call
(
args
)
call
(
args
)
args
[
3
]
=
'del'
args
[
3
]
=
'del'
...
@@ -299,7 +293,7 @@ def main():
...
@@ -299,7 +293,7 @@ def main():
timeout
=
4
*
cache
.
hello
timeout
=
4
*
cache
.
hello
cleanup
=
[
lambda
:
cache
.
cacheMinimize
(
config
.
client_count
),
cleanup
=
[
lambda
:
cache
.
cacheMinimize
(
config
.
client_count
),
lambda
:
shutil
.
rmtree
(
config
.
run
,
True
)]
lambda
:
shutil
.
rmtree
(
config
.
run
,
True
)]
utils
.
makedirs
(
config
.
run
,
0700
)
utils
.
makedirs
(
config
.
run
,
0
o
700
)
control_socket
=
os
.
path
.
join
(
config
.
run
,
'babeld.sock'
)
control_socket
=
os
.
path
.
join
(
config
.
run
,
'babeld.sock'
)
if
config
.
client_count
and
not
config
.
client
:
if
config
.
client_count
and
not
config
.
client
:
tunnel_manager
=
tunnel
.
TunnelManager
(
control_socket
,
tunnel_manager
=
tunnel
.
TunnelManager
(
control_socket
,
...
@@ -362,7 +356,7 @@ def main():
...
@@ -362,7 +356,7 @@ def main():
if
not
dh
:
if
not
dh
:
dh
=
os
.
path
.
join
(
config
.
state
,
"dh.pem"
)
dh
=
os
.
path
.
join
(
config
.
state
,
"dh.pem"
)
cache
.
getDh
(
dh
)
cache
.
getDh
(
dh
)
for
iface
,
(
port
,
proto
)
in
server_tunnels
.
ite
rite
ms
():
for
iface
,
(
port
,
proto
)
in
server_tunnels
.
items
():
r
,
x
=
socket
.
socketpair
(
socket
.
AF_UNIX
,
socket
.
SOCK_DGRAM
)
r
,
x
=
socket
.
socketpair
(
socket
.
AF_UNIX
,
socket
.
SOCK_DGRAM
)
utils
.
setCloexec
(
r
)
utils
.
setCloexec
(
r
)
cleanup
.
append
(
plib
.
server
(
iface
,
config
.
max_clients
,
cleanup
.
append
(
plib
.
server
(
iface
,
config
.
max_clients
,
...
@@ -442,7 +436,7 @@ def main():
...
@@ -442,7 +436,7 @@ def main():
except
:
except
:
pass
pass
exit
.
release
()
exit
.
release
()
except
ReexecException
,
e
:
except
ReexecException
as
e
:
logging
.
info
(
e
)
logging
.
info
(
e
)
except
Exception
:
except
Exception
:
utils
.
log_exception
()
utils
.
log_exception
()
...
@@ -455,7 +449,7 @@ def main():
...
@@ -455,7 +449,7 @@ def main():
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
try
:
try
:
main
()
main
()
except
SystemExit
,
e
:
except
SystemExit
as
e
:
if
type
(
e
.
code
)
is
str
:
if
type
(
e
.
code
)
is
str
:
if
hasattr
(
logging
,
'trace'
):
# utils.setupLog called
if
hasattr
(
logging
,
'trace'
):
# utils.setupLog called
logging
.
critical
(
e
.
code
)
logging
.
critical
(
e
.
code
)
...
...
re6st/cli/registry.py
View file @
e80ca878
#!/usr/bin/
python2
#!/usr/bin/
env python3
import
http
lib
,
logging
,
os
,
socket
,
sys
import
http
.client
,
logging
,
os
,
socket
,
sys
from
BaseHTTPS
erver
import
BaseHTTPRequestHandler
from
http.s
erver
import
BaseHTTPRequestHandler
from
SocketS
erver
import
ThreadingTCPServer
from
sockets
erver
import
ThreadingTCPServer
from
urlparse
import
parse_qsl
from
url
lib.
parse
import
parse_qsl
if
're6st'
not
in
sys
.
modules
:
if
're6st'
not
in
sys
.
modules
:
sys
.
path
[
0
]
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
sys
.
path
[
0
]))
sys
.
path
[
0
]
=
os
.
path
.
dirname
(
os
.
path
.
dirname
(
sys
.
path
[
0
]))
from
re6st
import
registry
,
utils
,
version
from
re6st
import
registry
,
utils
,
version
...
@@ -29,14 +29,14 @@ class RequestHandler(BaseHTTPRequestHandler):
...
@@ -29,14 +29,14 @@ class RequestHandler(BaseHTTPRequestHandler):
path
=
self
.
path
path
=
self
.
path
query
=
{}
query
=
{}
else
:
else
:
query
=
dict
(
parse_qsl
(
query
,
keep_blank_values
=
1
,
query
=
dict
(
parse_qsl
(
query
,
keep_blank_values
=
True
,
strict_parsing
=
1
))
strict_parsing
=
True
))
_
,
path
=
path
.
split
(
'/'
)
_
,
path
=
path
.
split
(
'/'
)
if
not
_
:
if
not
_
:
return
self
.
server
.
handle_request
(
self
,
path
,
query
)
return
self
.
server
.
handle_request
(
self
,
path
,
query
)
except
Exception
:
except
Exception
:
logging
.
info
(
self
.
requestline
,
exc_info
=
1
)
logging
.
info
(
self
.
requestline
,
exc_info
=
True
)
self
.
send_error
(
http
lib
.
BAD_REQUEST
)
self
.
send_error
(
http
.
client
.
BAD_REQUEST
)
def
log_error
(
*
args
):
def
log_error
(
*
args
):
pass
pass
...
...
re6st/ctl.py
View file @
e80ca878
...
@@ -5,7 +5,7 @@ from . import utils
...
@@ -5,7 +5,7 @@ from . import utils
uint16
=
struct
.
Struct
(
"!H"
)
uint16
=
struct
.
Struct
(
"!H"
)
header
=
struct
.
Struct
(
"!HI"
)
header
=
struct
.
Struct
(
"!HI"
)
class
Struct
(
object
)
:
class
Struct
:
def
__init__
(
self
,
format
,
*
args
):
def
__init__
(
self
,
format
,
*
args
):
if
args
:
if
args
:
...
@@ -29,39 +29,39 @@ class Struct(object):
...
@@ -29,39 +29,39 @@ class Struct(object):
self
.
encode
=
encode
self
.
encode
=
encode
self
.
decode
=
decode
self
.
decode
=
decode
class
Array
(
object
)
:
class
Array
:
def
__init__
(
self
,
item
):
def
__init__
(
self
,
item
):
self
.
_item
=
item
self
.
_item
=
item
def
encode
(
self
,
buffer
,
value
):
def
encode
(
self
,
buffer
:
bytes
,
value
:
list
):
buffer
+=
uint16
.
pack
(
len
(
value
))
buffer
+=
uint16
.
pack
(
len
(
value
))
encode
=
self
.
_item
.
encode
encode
=
self
.
_item
.
encode
for
value
in
value
:
for
value
in
value
:
encode
(
buffer
,
value
)
encode
(
buffer
,
value
)
def
decode
(
self
,
buffer
,
offset
=
0
)
:
def
decode
(
self
,
buffer
:
bytes
,
offset
=
0
)
->
tuple
[
int
,
list
]
:
r
=
[]
r
=
[]
o
=
offset
+
2
o
=
offset
+
2
decode
=
self
.
_item
.
decode
decode
=
self
.
_item
.
decode
for
i
in
x
range
(
*
uint16
.
unpack_from
(
buffer
,
offset
)):
for
i
in
range
(
*
uint16
.
unpack_from
(
buffer
,
offset
)):
o
,
x
=
decode
(
buffer
,
o
)
o
,
x
=
decode
(
buffer
,
o
)
r
.
append
(
x
)
r
.
append
(
x
)
return
o
,
r
return
o
,
r
class
String
(
object
)
:
class
String
:
@
staticmethod
@
staticmethod
def
encode
(
buffer
,
value
):
def
encode
(
buffer
:
bytes
,
value
:
str
):
buffer
+=
value
+
"
\
0
"
buffer
+=
value
.
encode
(
"utf-8"
)
+
b'
\
0
'
@
staticmethod
@
staticmethod
def
decode
(
buffer
,
offset
=
0
)
:
def
decode
(
buffer
:
bytes
,
offset
=
0
)
->
tuple
[
int
,
str
]
:
i
=
buffer
.
index
(
"
\
0
"
,
offset
)
i
=
buffer
.
index
(
0
,
offset
)
return
i
+
1
,
buffer
[
offset
:
i
]
return
i
+
1
,
buffer
[
offset
:
i
]
.
decode
(
"utf-8"
)
class
Buffer
(
object
)
:
class
Buffer
:
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_buf
=
bytearray
()
self
.
_buf
=
bytearray
()
...
@@ -104,21 +104,6 @@ class Buffer(object):
...
@@ -104,21 +104,6 @@ class Buffer(object):
self
.
_seek
(
r
)
self
.
_seek
(
r
)
return
value
return
value
try
:
# BBB: Python < 2.7.4 (http://bugs.python.org/issue10212)
uint16
.
unpack_from
(
bytearray
(
uint16
.
size
))
except
TypeError
:
def
unpack_from
(
self
,
struct
):
r
=
self
.
_r
x
=
r
+
struct
.
size
value
=
struct
.
unpack
(
buffer
(
self
.
_buf
)[
r
:
x
])
self
.
_seek
(
x
)
return
value
def
decode
(
self
,
decode
):
r
=
self
.
_r
size
,
value
=
decode
(
buffer
(
self
.
_buf
)[
r
:])
self
.
_seek
(
r
+
size
)
return
value
# writing
# writing
def
send
(
self
,
socket
,
*
args
):
def
send
(
self
,
socket
,
*
args
):
...
@@ -129,7 +114,7 @@ class Buffer(object):
...
@@ -129,7 +114,7 @@ class Buffer(object):
struct
.
pack_into
(
self
.
_buf
,
offset
,
*
args
)
struct
.
pack_into
(
self
.
_buf
,
offset
,
*
args
)
class
Packet
(
object
)
:
class
Packet
:
response_dict
=
{}
response_dict
=
{}
...
@@ -149,7 +134,7 @@ class Packet(object):
...
@@ -149,7 +134,7 @@ class Packet(object):
logging
.
trace
(
'send %s%r'
,
self
.
__class__
.
__name__
,
logging
.
trace
(
'send %s%r'
,
self
.
__class__
.
__name__
,
(
self
.
id
,)
+
self
.
args
)
(
self
.
id
,)
+
self
.
args
)
offset
=
len
(
buffer
)
offset
=
len
(
buffer
)
buffer
+=
'
\
0
'
*
header
.
size
buffer
+=
bytes
(
header
.
size
)
r
=
self
.
request
r
=
self
.
request
if
isinstance
(
r
,
Struct
):
if
isinstance
(
r
,
Struct
):
r
.
encode
(
buffer
,
self
.
args
)
r
.
encode
(
buffer
,
self
.
args
)
...
@@ -182,11 +167,11 @@ class ConnectionClosed(BabelException):
...
@@ -182,11 +167,11 @@ class ConnectionClosed(BabelException):
return
"connection to babeld closed (%s)"
%
self
.
args
return
"connection to babeld closed (%s)"
%
self
.
args
class
Babel
(
object
)
:
class
Babel
:
_decode
=
None
_decode
=
None
def
__init__
(
self
,
socket_path
,
handler
,
network
):
def
__init__
(
self
,
socket_path
:
str
,
handler
,
network
:
str
):
self
.
socket_path
=
socket_path
self
.
socket_path
=
socket_path
self
.
handler
=
handler
self
.
handler
=
handler
self
.
network
=
network
self
.
network
=
network
...
@@ -206,11 +191,11 @@ class Babel(object):
...
@@ -206,11 +191,11 @@ class Babel(object):
def
select
(
*
args
):
def
select
(
*
args
):
try
:
try
:
s
.
connect
(
self
.
socket_path
)
s
.
connect
(
self
.
socket_path
)
except
socket
.
error
,
e
:
except
socket
.
error
as
e
:
logging
.
debug
(
"Can't connect to %r (%r)"
,
self
.
socket_path
,
e
)
logging
.
debug
(
"Can't connect to %r (%r)"
,
self
.
socket_path
,
e
)
return
e
return
e
s
.
send
(
"
\
1
"
)
s
.
send
(
b'
\
1
'
)
s
.
setblocking
(
0
)
s
.
setblocking
(
False
)
del
self
.
select
del
self
.
select
self
.
socket
=
s
self
.
socket
=
s
return
self
.
select
(
*
args
)
return
self
.
select
(
*
args
)
...
@@ -267,15 +252,18 @@ class Babel(object):
...
@@ -267,15 +252,18 @@ class Babel(object):
unidentified
=
set
(
n
)
unidentified
=
set
(
n
)
self
.
neighbours
=
neighbours
=
{}
self
.
neighbours
=
neighbours
=
{}
a
=
len
(
self
.
network
)
a
=
len
(
self
.
network
)
logging
.
info
(
"Routes: %r"
,
routes
)
for
route
in
routes
:
for
route
in
routes
:
assert
route
.
flags
&
1
,
route
# installed
assert
route
.
flags
&
1
,
route
# installed
if
route
.
prefix
.
startswith
(
'
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
xff
\
xff
'
):
if
route
.
prefix
.
startswith
(
b'
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
0
\
xff
\
xff
'
):
logging
.
warning
(
"Ignoring IPv4 route: %r"
,
route
)
continue
continue
assert
route
.
neigh_address
==
route
.
nexthop
,
route
assert
route
.
neigh_address
==
route
.
nexthop
,
route
address
=
route
.
neigh_address
,
route
.
ifindex
address
=
route
.
neigh_address
,
route
.
ifindex
neigh_routes
=
n
[
address
]
neigh_routes
=
n
[
address
]
ip
=
utils
.
binFromRawIp
(
route
.
prefix
)
ip
=
utils
.
binFromRawIp
(
route
.
prefix
)
if
ip
[:
a
]
==
self
.
network
:
if
ip
[:
a
]
==
self
.
network
:
logging
.
debug
(
"Route is on the network: %r"
,
route
)
prefix
=
ip
[
a
:
route
.
plen
]
prefix
=
ip
[
a
:
route
.
plen
]
if
prefix
and
not
route
.
refmetric
:
if
prefix
and
not
route
.
refmetric
:
neighbours
[
prefix
]
=
neigh_routes
neighbours
[
prefix
]
=
neigh_routes
...
@@ -290,7 +278,9 @@ class Babel(object):
...
@@ -290,7 +278,9 @@ class Babel(object):
socket
.
inet_ntop
(
socket
.
AF_INET6
,
route
.
prefix
),
socket
.
inet_ntop
(
socket
.
AF_INET6
,
route
.
prefix
),
route
.
plen
)
route
.
plen
)
else
:
else
:
logging
.
debug
(
"Route is not on the network: %r"
,
route
)
prefix
=
None
prefix
=
None
logging
.
debug
(
"Adding route %r to %r"
,
route
,
neigh_routes
)
neigh_routes
[
1
][
prefix
]
=
route
neigh_routes
[
1
][
prefix
]
=
route
self
.
locked
.
clear
()
self
.
locked
.
clear
()
if
unidentified
:
if
unidentified
:
...
@@ -310,11 +300,11 @@ class Babel(object):
...
@@ -310,11 +300,11 @@ class Babel(object):
pass
pass
class
iterRoutes
(
object
)
:
class
iterRoutes
:
_waiting
=
True
_waiting
=
True
def
__new__
(
cls
,
control_socket
,
network
):
def
__new__
(
cls
,
control_socket
:
str
,
network
:
str
):
self
=
object
.
__new__
(
cls
)
self
=
object
.
__new__
(
cls
)
c
=
Babel
(
control_socket
,
self
,
network
)
c
=
Babel
(
control_socket
,
self
,
network
)
c
.
request_dump
()
c
.
request_dump
()
...
@@ -323,7 +313,7 @@ class iterRoutes(object):
...
@@ -323,7 +313,7 @@ class iterRoutes(object):
c
.
select
(
*
args
)
c
.
select
(
*
args
)
utils
.
select
(
*
args
)
utils
.
select
(
*
args
)
return
(
prefix
return
(
prefix
for
neigh_routes
in
c
.
neighbours
.
iter
values
()
for
neigh_routes
in
c
.
neighbours
.
values
()
for
prefix
in
neigh_routes
[
1
]
for
prefix
in
neigh_routes
[
1
]
if
prefix
)
if
prefix
)
...
...
re6st/debug.py
View file @
e80ca878
import
errno
,
os
,
socket
,
stat
,
threading
import
errno
,
os
,
socket
,
stat
,
threading
class
Socket
(
object
)
:
class
Socket
:
def
__init__
(
self
,
socket
):
def
__init__
(
self
,
socket
:
socket
.
socket
):
# In case that the default timeout is not None.
# In case that the default timeout is not None.
socket
.
settimeout
(
None
)
socket
.
settimeout
(
None
)
self
.
_socket
=
socket
self
.
_socket
=
socket
self
.
_buf
=
''
self
.
_buf
=
b
''
def
close
(
self
):
def
close
(
self
):
self
.
_socket
.
close
()
self
.
_socket
.
close
()
def
write
(
self
,
data
):
def
write
(
self
,
data
:
bytes
):
self
.
_socket
.
send
(
data
)
self
.
_socket
.
send
(
data
)
def
readline
(
self
):
def
readline
(
self
)
->
bytes
:
recv
=
self
.
_socket
.
recv
recv
=
self
.
_socket
.
recv
data
=
self
.
_buf
data
=
self
.
_buf
while
True
:
while
True
:
i
=
1
+
data
.
find
(
'
\
n
'
)
i
=
1
+
data
.
find
(
b
'
\
n
'
)
if
i
:
if
i
:
self
.
_buf
=
data
[
i
:]
self
.
_buf
=
data
[
i
:]
return
data
[:
i
]
return
data
[:
i
]
d
=
recv
(
4096
)
d
=
recv
(
4096
)
data
+=
d
data
+=
d
if
not
d
:
if
not
d
:
self
.
_buf
=
''
self
.
_buf
=
b
''
return
data
return
data
def
flush
(
self
):
def
flush
(
self
):
...
@@ -37,14 +37,14 @@ class Socket(object):
...
@@ -37,14 +37,14 @@ class Socket(object):
try
:
try
:
self
.
_socket
.
recv
(
0
)
self
.
_socket
.
recv
(
0
)
return
True
return
True
except
socket
.
error
,
(
err
,
_
)
:
except
socket
.
error
as
e
:
if
e
rr
!=
errno
.
EAGAIN
:
if
e
.
errno
!=
errno
.
EAGAIN
:
raise
raise
self
.
_socket
.
setblocking
(
1
)
self
.
_socket
.
setblocking
(
1
)
return
False
return
False
class
Console
(
object
)
:
class
Console
:
def
__init__
(
self
,
path
,
pdb
):
def
__init__
(
self
,
path
,
pdb
):
self
.
path
=
path
self
.
path
=
path
...
@@ -52,7 +52,7 @@ class Console(object):
...
@@ -52,7 +52,7 @@ class Console(object):
socket
.
SOCK_STREAM
|
socket
.
SOCK_CLOEXEC
)
socket
.
SOCK_STREAM
|
socket
.
SOCK_CLOEXEC
)
try
:
try
:
self
.
_removeSocket
()
self
.
_removeSocket
()
except
OSError
,
e
:
except
OSError
as
e
:
if
e
.
errno
!=
errno
.
ENOENT
:
if
e
.
errno
!=
errno
.
ENOENT
:
raise
raise
s
.
bind
(
path
)
s
.
bind
(
path
)
...
...
re6st/multicast.py
View file @
e80ca878
...
@@ -43,7 +43,7 @@ freeifaddrs = libc.freeifaddrs
...
@@ -43,7 +43,7 @@ freeifaddrs = libc.freeifaddrs
freeifaddrs
.
restype
=
None
freeifaddrs
.
restype
=
None
freeifaddrs
.
argtypes
=
[
POINTER
(
struct_ifaddrs
)]
freeifaddrs
.
argtypes
=
[
POINTER
(
struct_ifaddrs
)]
class
unpacker
(
object
)
:
class
unpacker
:
def
__init__
(
self
,
buf
):
def
__init__
(
self
,
buf
):
self
.
_buf
=
buf
self
.
_buf
=
buf
...
@@ -55,7 +55,7 @@ class unpacker(object):
...
@@ -55,7 +55,7 @@ class unpacker(object):
self
.
_offset
+=
s
.
size
self
.
_offset
+=
s
.
size
return
result
return
result
class
PimDm
(
object
)
:
class
PimDm
:
def
__init__
(
self
):
def
__init__
(
self
):
s_netlink
=
socket
(
AF_NETLINK
,
SOCK_RAW
,
NETLINK_ROUTE
)
s_netlink
=
socket
(
AF_NETLINK
,
SOCK_RAW
,
NETLINK_ROUTE
)
...
...
re6st/ovpn-client
View file @
e80ca878
#!/usr/bin/
python2
-S
#!/usr/bin/
env -S python3
-S
import
os
,
sys
import
os
,
sys
script_type
=
os
.
environ
[
'script_type'
]
script_type
=
os
.
environ
[
'script_type'
]
...
@@ -14,4 +14,5 @@ if script_type == 'up':
...
@@ -14,4 +14,5 @@ if script_type == 'up':
if
script_type
==
'route-up'
:
if
script_type
==
'route-up'
:
import
time
import
time
os
.
write
(
int
(
sys
.
argv
[
1
]),
repr
((
os
.
environ
[
'common_name'
],
time
.
time
(),
os
.
write
(
int
(
sys
.
argv
[
1
]),
repr
((
os
.
environ
[
'common_name'
],
time
.
time
(),
int
(
os
.
environ
[
'tls_serial_0'
]),
os
.
environ
[
'OPENVPN_external_ip'
])))
int
(
os
.
environ
[
'tls_serial_0'
]),
os
.
environ
[
'OPENVPN_external_ip'
]))
.
encode
())
re6st/ovpn-server
View file @
e80ca878
#!/usr/bin/
python2
-S
#!/usr/bin/
env -S python3
-S
import
os
,
sys
import
os
,
sys
script_type
=
os
.
environ
[
'script_type'
]
script_type
=
os
.
environ
[
'script_type'
]
...
@@ -7,10 +7,11 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
...
@@ -7,10 +7,11 @@ external_ip = os.getenv('trusted_ip') or os.environ['trusted_ip6']
# Write into pipe connect/disconnect events
# Write into pipe connect/disconnect events
fd
=
int
(
sys
.
argv
[
1
])
fd
=
int
(
sys
.
argv
[
1
])
os
.
write
(
fd
,
repr
((
script_type
,
(
os
.
environ
[
'common_name'
],
os
.
environ
[
'dev'
],
os
.
write
(
fd
,
repr
((
script_type
,
(
os
.
environ
[
'common_name'
],
os
.
environ
[
'dev'
],
int
(
os
.
environ
[
'tls_serial_0'
]),
external_ip
))))
int
(
os
.
environ
[
'tls_serial_0'
]),
external_ip
)))
.
encode
(
"utf-8"
))
if
script_type
==
'client-connect'
:
if
script_type
==
'client-connect'
:
if
os
.
read
(
fd
,
1
)
==
'
\
0
'
:
if
os
.
read
(
fd
,
1
)
==
b
'
\
0
'
:
sys
.
exit
(
1
)
sys
.
exit
(
1
)
# Send client its external ip address
# Send client its external ip address
with
open
(
sys
.
argv
[
2
],
'w'
)
as
f
:
with
open
(
sys
.
argv
[
2
],
'w'
)
as
f
:
...
...
re6st/plib.py
View file @
e80ca878
import
binascii
import
logging
,
errno
,
os
import
logging
,
errno
,
os
from
typing
import
Optional
from
.
import
utils
from
.
import
utils
here
=
os
.
path
.
realpath
(
os
.
path
.
dirname
(
__file__
))
here
=
os
.
path
.
realpath
(
os
.
path
.
dirname
(
__file__
))
ovpn_server
=
os
.
path
.
join
(
here
,
'ovpn-server'
)
ovpn_server
=
os
.
path
.
join
(
here
,
'ovpn-server'
)
ovpn_client
=
os
.
path
.
join
(
here
,
'ovpn-client'
)
ovpn_client
=
os
.
path
.
join
(
here
,
'ovpn-client'
)
ovpn_log
=
None
ovpn_log
:
Optional
[
str
]
=
None
def
openvpn
(
iface
,
encrypt
,
*
args
,
**
kw
)
:
def
openvpn
(
iface
:
str
,
encrypt
,
*
args
,
**
kw
)
->
utils
.
Popen
:
args
=
[
'openvpn'
,
args
=
[
'openvpn'
,
'--dev-type'
,
'tap'
,
'--dev-type'
,
'tap'
,
'--dev'
,
iface
,
'--dev'
,
iface
,
...
@@ -19,13 +21,16 @@ def openvpn(iface, encrypt, *args, **kw):
...
@@ -19,13 +21,16 @@ def openvpn(iface, encrypt, *args, **kw):
if
ovpn_log
:
if
ovpn_log
:
args
+=
'--log-append'
,
os
.
path
.
join
(
ovpn_log
,
'%s.log'
%
iface
),
args
+=
'--log-append'
,
os
.
path
.
join
(
ovpn_log
,
'%s.log'
%
iface
),
if
not
encrypt
:
if
not
encrypt
:
# TODO: --ncp-disable was deprecated in OpenVPN 2.5 and removed in 2.6
# and is no longer necessary in those versions.
args
+=
'--cipher'
,
'none'
,
'--ncp-disable'
args
+=
'--cipher'
,
'none'
,
'--ncp-disable'
logging
.
debug
(
'%r'
,
args
)
logging
.
debug
(
'%r'
,
args
)
return
utils
.
Popen
(
args
,
**
kw
)
return
utils
.
Popen
(
args
,
**
kw
)
ovpn_link_mtu_dict
=
{
'udp4'
:
1432
,
'udp6'
:
1450
}
ovpn_link_mtu_dict
=
{
'udp4'
:
1432
,
'udp6'
:
1450
}
def
server
(
iface
,
max_clients
,
dh_path
,
fd
,
port
,
proto
,
encrypt
,
*
args
,
**
kw
):
def
server
(
iface
:
str
,
max_clients
:
int
,
dh_path
:
str
,
fd
:
int
,
port
:
int
,
proto
:
str
,
encrypt
:
bool
,
*
args
,
**
kw
)
->
utils
.
Popen
:
if
proto
==
'udp'
:
if
proto
==
'udp'
:
proto
=
'udp4'
proto
=
'udp4'
client_script
=
'%s %s'
%
(
ovpn_server
,
fd
)
client_script
=
'%s %s'
%
(
ovpn_server
,
fd
)
...
@@ -43,10 +48,11 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
...
@@ -43,10 +48,11 @@ def server(iface, max_clients, dh_path, fd, port, proto, encrypt, *args, **kw):
'--max-clients'
,
str
(
max_clients
),
'--max-clients'
,
str
(
max_clients
),
'--port'
,
str
(
port
),
'--port'
,
str
(
port
),
'--proto'
,
proto
,
'--proto'
,
proto
,
*
args
,
**
kw
)
*
args
,
pass_fds
=
[
fd
],
**
kw
)
def
client
(
iface
,
address_list
,
encrypt
,
*
args
,
**
kw
):
def
client
(
iface
:
str
,
address_list
:
list
[
tuple
[
str
,
int
,
str
]],
encrypt
:
bool
,
*
args
,
**
kw
)
->
utils
.
Popen
:
remote
=
[
'--nobind'
,
'--client'
]
remote
=
[
'--nobind'
,
'--client'
]
# XXX: We'd like to pass <connection> sections at command-line.
# XXX: We'd like to pass <connection> sections at command-line.
link_mtu
=
set
()
link_mtu
=
set
()
...
@@ -62,8 +68,10 @@ def client(iface, address_list, encrypt, *args, **kw):
...
@@ -62,8 +68,10 @@ def client(iface, address_list, encrypt, *args, **kw):
return
openvpn
(
iface
,
encrypt
,
*
remote
,
**
kw
)
return
openvpn
(
iface
,
encrypt
,
*
remote
,
**
kw
)
def
router
(
ip
,
ip4
,
rt6
,
hello_interval
,
log_path
,
state_path
,
pidfile
,
def
router
(
ip
:
tuple
[
str
,
int
],
ip4
,
rt6
:
tuple
[
str
,
bool
,
bool
],
control_socket
,
default
,
hmac
,
*
args
,
**
kw
):
hello_interval
:
int
,
log_path
:
str
,
state_path
:
str
,
pidfile
:
str
,
control_socket
:
str
,
default
:
str
,
hmac
:
tuple
[
bytes
|
None
,
bytes
|
None
],
*
args
,
**
kw
)
->
utils
.
Popen
:
network
,
gateway
,
has_ipv6_subtrees
=
rt6
network
,
gateway
,
has_ipv6_subtrees
=
rt6
network_mask
=
int
(
network
[
network
.
index
(
'/'
)
+
1
:])
network_mask
=
int
(
network
[
network
.
index
(
'/'
)
+
1
:])
ip
,
n
=
ip
ip
,
n
=
ip
...
@@ -80,9 +88,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
...
@@ -80,9 +88,9 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
'-C'
,
'redistribute local deny'
,
'-C'
,
'redistribute local deny'
,
'-C'
,
'redistribute ip %s/%s eq %s'
%
(
ip
,
n
,
n
)]
'-C'
,
'redistribute ip %s/%s eq %s'
%
(
ip
,
n
,
n
)]
if
hmac_sign
:
if
hmac_sign
:
def
key
(
cmd
,
id
,
value
):
def
key
(
cmd
:
list
[
str
],
id
:
str
,
value
:
bytes
):
cmd
+=
'-C'
,
(
'key type blake2s128 id %s value %s'
%
cmd
+=
'-C'
,
(
'key type blake2s128 id %s value %s'
%
(
id
,
value
.
encode
(
'hex'
)))
(
id
,
binascii
.
hexlify
(
value
).
decode
(
)))
key
(
cmd
,
'sign'
,
hmac_sign
)
key
(
cmd
,
'sign'
,
hmac_sign
)
default
+=
' key sign'
default
+=
' key sign'
if
hmac_accept
is
not
None
:
if
hmac_accept
is
not
None
:
...
@@ -132,7 +140,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
...
@@ -132,7 +140,7 @@ def router(ip, ip4, rt6, hello_interval, log_path, state_path, pidfile,
# WKRD: babeld fails to start if pidfile already exists
# WKRD: babeld fails to start if pidfile already exists
try
:
try
:
os
.
remove
(
pidfile
)
os
.
remove
(
pidfile
)
except
OSError
,
e
:
except
OSError
as
e
:
if
e
.
errno
!=
errno
.
ENOENT
:
if
e
.
errno
!=
errno
.
ENOENT
:
raise
raise
logging
.
info
(
'%r'
,
cmd
)
logging
.
info
(
'%r'
,
cmd
)
...
...
re6st/registry.py
View file @
e80ca878
...
@@ -18,16 +18,19 @@ Authenticated communication:
...
@@ -18,16 +18,19 @@ Authenticated communication:
- the last one that was really used by the client (!hello)
- the last one that was really used by the client (!hello)
- the one of the last handshake (hello)
- the one of the last handshake (hello)
"""
"""
import
base64
,
hmac
,
hashlib
,
http
lib
,
inspect
,
json
,
logging
import
base64
,
hmac
,
hashlib
,
http
.
client
,
inspect
,
json
,
logging
import
mailbox
,
os
,
platform
,
random
,
select
,
smtplib
,
socket
,
sqlite3
import
mailbox
,
os
,
platform
,
random
,
select
,
smtplib
,
socket
,
sqlite3
import
string
,
sys
,
threading
,
time
,
weakref
,
zlib
import
string
,
sys
,
threading
,
time
,
weakref
,
zlib
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
collections.abc
import
Iterator
from
datetime
import
datetime
from
datetime
import
datetime
from
BaseHTTPS
erver
import
HTTPServer
,
BaseHTTPRequestHandler
from
http.s
erver
import
HTTPServer
,
BaseHTTPRequestHandler
from
email.mime.text
import
MIMEText
from
email.mime.text
import
MIMEText
from
operator
import
itemgetter
from
operator
import
itemgetter
from
typing
import
Tuple
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
from
urllib
import
splittype
,
splithost
,
unquote
,
urlencode
from
urllib
.parse
import
urlparse
,
unquote
,
urlencode
from
.
import
ctl
,
tunnel
,
utils
,
version
,
x509
from
.
import
ctl
,
tunnel
,
utils
,
version
,
x509
HMAC_HEADER
=
"Re6stHMAC"
HMAC_HEADER
=
"Re6stHMAC"
...
@@ -35,13 +38,13 @@ RENEW_PERIOD = 30 * 86400
...
@@ -35,13 +38,13 @@ RENEW_PERIOD = 30 * 86400
BABEL_HMAC
=
'babel_hmac0'
,
'babel_hmac1'
,
'babel_hmac2'
BABEL_HMAC
=
'babel_hmac0'
,
'babel_hmac1'
,
'babel_hmac2'
def
rpc
(
f
):
def
rpc
(
f
):
args
,
varargs
,
varkw
,
defaults
=
inspect
.
get
argspec
(
f
)
args
pec
=
inspect
.
getfull
argspec
(
f
)
assert
not
(
varargs
or
varkw
),
f
assert
not
(
argspec
.
varargs
or
argspec
.
varkw
),
f
if
not
defaults
:
sig
=
inspect
.
signature
(
f
)
defaults
=
(
)
sig
=
sig
.
replace
(
parameters
=
[
v
.
replace
(
annotation
=
inspect
.
Parameter
.
empty
)
i
=
len
(
args
)
-
len
(
defaults
)
for
v
in
sig
.
parameters
.
values
()][
1
:],
f
.
getcallargs
=
eval
(
"lambda %s: locals()"
%
','
.
join
(
args
[
1
:
i
]
return_annotation
=
inspect
.
Signature
.
empty
)
+
map
(
"%s=%r"
.
__mod__
,
zip
(
args
[
i
:],
defaults
)))
)
f
.
getcallargs
=
eval
(
"lambda %s: locals()"
%
str
(
sig
)[
1
:
-
1
]
)
return
f
return
f
def
rpc_private
(
f
):
def
rpc_private
(
f
):
...
@@ -53,13 +56,15 @@ class HTTPError(Exception):
...
@@ -53,13 +56,15 @@ class HTTPError(Exception):
pass
pass
class
RegistryServer
(
object
)
:
class
RegistryServer
:
peers
=
0
,
()
peers
=
0
,
()
cert_duration
=
365
*
86400
cert_duration
=
365
*
86400
sessions
:
dict
[
str
,
list
[
tuple
[
bytes
,
int
]]]
def
_geoiplookup
(
self
,
ip
):
def
_geoiplookup
(
self
,
ip
):
raise
HTTPError
(
http
lib
.
BAD_REQUEST
)
raise
HTTPError
(
http
.
client
.
BAD_REQUEST
)
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
self
.
config
=
config
self
.
config
=
config
...
@@ -69,14 +74,14 @@ class RegistryServer(object):
...
@@ -69,14 +74,14 @@ class RegistryServer(object):
# Parse community file
# Parse community file
self
.
community_map
=
{}
self
.
community_map
=
{}
if
config
.
community
:
if
config
.
community
:
with
open
(
config
.
community
)
as
x
:
with
open
(
config
.
community
)
as
x
:
for
x
in
x
:
for
x
in
x
:
x
=
x
.
strip
()
x
=
x
.
strip
()
if
x
and
not
x
.
startswith
(
'#'
):
if
x
and
not
x
.
startswith
(
'#'
):
x
=
x
.
split
()
x
=
x
.
split
()
self
.
community_map
[
x
.
pop
(
0
)]
=
x
self
.
community_map
[
x
.
pop
(
0
)]
=
x
if
sum
(
'*'
in
x
for
x
in
self
.
community_map
.
iter
values
())
!=
1
:
if
sum
(
'*'
in
x
for
x
in
self
.
community_map
.
values
())
!=
1
:
sys
.
exit
(
"Invalid community configuration: missing or multiple default location ('*')"
)
sys
.
exit
(
"Invalid community configuration: missing or multiple default location ('*')"
)
else
:
else
:
self
.
community_map
[
''
]
=
'*'
self
.
community_map
[
''
]
=
'*'
...
@@ -91,7 +96,7 @@ class RegistryServer(object):
...
@@ -91,7 +96,7 @@ class RegistryServer(object):
"name TEXT PRIMARY KEY NOT NULL"
,
"name TEXT PRIMARY KEY NOT NULL"
,
"value"
)
"value"
)
self
.
prefix
=
self
.
getConfig
(
"prefix"
,
None
)
self
.
prefix
=
self
.
getConfig
(
"prefix"
,
None
)
self
.
version
=
s
tr
(
self
.
getConfig
(
"version"
,
"
\
0
"
))
# BBB: blob
self
.
version
=
s
elf
.
getConfig
(
"version"
,
b'
\
0
'
)
utils
.
sqliteCreateTable
(
self
.
db
,
"token"
,
utils
.
sqliteCreateTable
(
self
.
db
,
"token"
,
"token TEXT PRIMARY KEY NOT NULL"
,
"token TEXT PRIMARY KEY NOT NULL"
,
"email TEXT NOT NULL"
,
"email TEXT NOT NULL"
,
...
@@ -102,6 +107,7 @@ class RegistryServer(object):
...
@@ -102,6 +107,7 @@ class RegistryServer(object):
"email TEXT"
,
"email TEXT"
,
"cert TEXT"
)
"cert TEXT"
)
if
not
self
.
db
.
execute
(
"SELECT 1 FROM cert LIMIT 1"
).
fetchone
():
if
not
self
.
db
.
execute
(
"SELECT 1 FROM cert LIMIT 1"
).
fetchone
():
logging
.
debug
(
"No existing certs found; creating an unallocated cert"
)
self
.
db
.
execute
(
"INSERT INTO cert VALUES ('',null,null)"
)
self
.
db
.
execute
(
"INSERT INTO cert VALUES ('',null,null)"
)
prev
=
'-'
prev
=
'-'
...
@@ -139,10 +145,10 @@ class RegistryServer(object):
...
@@ -139,10 +145,10 @@ class RegistryServer(object):
if
self
.
geoip_db
:
if
self
.
geoip_db
:
from
geoip2
import
database
,
errors
from
geoip2
import
database
,
errors
country
=
database
.
Reader
(
self
.
geoip_db
).
country
country
=
database
.
Reader
(
self
.
geoip_db
).
country
def
geoiplookup
(
ip
)
:
def
geoiplookup
(
ip
:
str
)
->
Tuple
[
str
,
str
]
:
try
:
try
:
req
=
country
(
ip
)
req
=
country
(
ip
)
return
req
.
country
.
iso_code
.
encode
(),
req
.
continent
.
code
.
encode
()
return
req
.
country
.
iso_code
,
req
.
continent
.
code
except
(
errors
.
AddressNotFoundError
,
ValueError
):
except
(
errors
.
AddressNotFoundError
,
ValueError
):
return
'*'
,
'*'
return
'*'
,
'*'
self
.
_geoiplookup
=
geoiplookup
self
.
_geoiplookup
=
geoiplookup
...
@@ -157,6 +163,11 @@ class RegistryServer(object):
...
@@ -157,6 +163,11 @@ class RegistryServer(object):
else
:
else
:
self
.
newHMAC
(
0
)
self
.
newHMAC
(
0
)
def
close
(
self
):
self
.
sock
.
close
()
self
.
db
.
close
()
self
.
ctl
.
close
()
def
getConfig
(
self
,
name
,
*
default
):
def
getConfig
(
self
,
name
,
*
default
):
r
,
=
next
(
self
.
db
.
execute
(
r
,
=
next
(
self
.
db
.
execute
(
"SELECT value FROM config WHERE name=?"
,
(
name
,)),
default
)
"SELECT value FROM config WHERE name=?"
,
(
name
,)),
default
)
...
@@ -169,8 +180,8 @@ class RegistryServer(object):
...
@@ -169,8 +180,8 @@ class RegistryServer(object):
def
updateNetworkConfig
(
self
,
_it0
=
itemgetter
(
0
)):
def
updateNetworkConfig
(
self
,
_it0
=
itemgetter
(
0
)):
kw
=
{
kw
=
{
'babel_default'
:
'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125'
,
'babel_default'
:
'max-rtt-penalty 5000 rtt-max 500 rtt-decay 125'
,
'crl'
:
map
(
_it0
,
self
.
db
.
execute
(
'crl'
:
list
(
map
(
_it0
,
self
.
db
.
execute
(
"SELECT serial FROM crl ORDER BY serial"
)),
"SELECT serial FROM crl ORDER BY serial"
))
)
,
'protocol'
:
version
.
protocol
,
'protocol'
:
version
.
protocol
,
'registry_prefix'
:
self
.
prefix
,
'registry_prefix'
:
self
.
prefix
,
}
}
...
@@ -184,32 +195,32 @@ class RegistryServer(object):
...
@@ -184,32 +195,32 @@ class RegistryServer(object):
config
=
json
.
dumps
(
kw
,
sort_keys
=
True
)
config
=
json
.
dumps
(
kw
,
sort_keys
=
True
)
if
config
!=
self
.
getConfig
(
'last_config'
,
None
):
if
config
!=
self
.
getConfig
(
'last_config'
,
None
):
self
.
increaseVersion
()
self
.
increaseVersion
()
# BBB: Use buffer because of http://bugs.python.org/issue13676
self
.
setConfig
(
'version'
,
self
.
version
)
# on Python 2.6
self
.
setConfig
(
'version'
,
buffer
(
self
.
version
))
self
.
setConfig
(
'last_config'
,
config
)
self
.
setConfig
(
'last_config'
,
config
)
self
.
sendto
(
self
.
prefix
,
0
)
self
.
sendto
(
self
.
prefix
,
0
)
# The following entry lists values that are base64-encoded.
# The following entry lists values that are base64-encoded.
kw
[
''
]
=
'version'
,
kw
[
''
]
=
'version'
,
kw
[
'version'
]
=
self
.
version
.
encode
(
'base64'
)
kw
[
'version'
]
=
base64
.
b64encode
(
self
.
version
).
decode
(
)
self
.
network_config
=
kw
self
.
network_config
=
kw
def
increaseVersion
(
self
):
def
increaseVersion
(
self
):
x
=
utils
.
packInteger
(
1
+
utils
.
unpackInteger
(
self
.
version
)[
0
])
x
=
utils
.
packInteger
(
1
+
utils
.
unpackInteger
(
self
.
version
)[
0
])
self
.
version
=
x
+
self
.
cert
.
sign
(
x
)
self
.
version
=
x
+
self
.
cert
.
sign
(
x
)
def
sendto
(
self
,
prefix
,
code
):
def
sendto
(
self
,
prefix
:
str
,
code
:
int
):
self
.
sock
.
sendto
(
"%s
\
0
%c"
%
(
prefix
,
code
),
(
'::1'
,
tunnel
.
PORT
))
self
.
sock
.
sendto
(
prefix
.
encode
()
+
bytes
((
0
,
code
)),
(
'::1'
,
tunnel
.
PORT
))
def
recv
(
self
,
code
)
:
def
recv
(
self
,
code
:
int
)
->
tuple
[
str
,
str
]
|
tuple
[
None
,
None
]
:
try
:
try
:
prefix
,
msg
=
self
.
sock
.
recv
(
1
<<
16
).
split
(
'
\
0
'
,
1
)
prefix
,
msg
=
self
.
sock
.
recv
(
1
<<
16
).
split
(
b
'
\
0
'
,
1
)
int
(
prefix
,
2
)
int
(
prefix
,
2
)
except
ValueError
:
except
ValueError
:
pass
pass
else
:
else
:
if
msg
and
ord
(
msg
[
0
])
==
code
:
if
len
(
msg
)
>=
1
and
msg
[
0
]
==
code
:
return
prefix
,
msg
[
1
:]
return
prefix
.
decode
(),
msg
[
1
:].
decode
()
logging
.
error
(
"Invalid message or unexpected code: %r"
,
msg
)
return
None
,
None
return
None
,
None
def
select
(
self
,
r
,
w
,
t
):
def
select
(
self
,
r
,
w
,
t
):
...
@@ -235,7 +246,7 @@ class RegistryServer(object):
...
@@ -235,7 +246,7 @@ class RegistryServer(object):
def
babel_dump
(
self
):
def
babel_dump
(
self
):
self
.
_wait_dump
=
False
self
.
_wait_dump
=
False
def
iterCert
(
self
):
def
iterCert
(
self
)
->
Iterator
[
Tuple
[
crypto
.
X509
,
str
,
str
]]
:
for
prefix
,
email
,
cert
in
self
.
db
.
execute
(
for
prefix
,
email
,
cert
in
self
.
db
.
execute
(
"SELECT * FROM cert WHERE cert IS NOT NULL"
):
"SELECT * FROM cert WHERE cert IS NOT NULL"
):
try
:
try
:
...
@@ -263,8 +274,10 @@ class RegistryServer(object):
...
@@ -263,8 +274,10 @@ class RegistryServer(object):
x
=
x509
.
notAfter
(
cert
)
x
=
x509
.
notAfter
(
cert
)
if
x
<=
old
:
if
x
<=
old
:
if
prefix
==
self
.
prefix
:
if
prefix
==
self
.
prefix
:
logging
.
critical
(
"Refuse to delete certificate"
logging
.
critical
(
" of main node: wrong clock ?"
)
"Refuse to delete certificate of main node:"
" wrong clock ? Alternatively, the database"
" might be in an inconsistent state."
)
sys
.
exit
(
1
)
sys
.
exit
(
1
)
logging
.
info
(
"Delete %s: %s (invalid since %s)"
,
logging
.
info
(
"Delete %s: %s (invalid since %s)"
,
"certificate requested by '%s'"
%
email
"certificate requested by '%s'"
%
email
...
@@ -286,14 +299,15 @@ class RegistryServer(object):
...
@@ -286,14 +299,15 @@ class RegistryServer(object):
x_forwarded_for
=
request
.
headers
.
get
(
'X-Forwarded-For'
)
x_forwarded_for
=
request
.
headers
.
get
(
'X-Forwarded-For'
)
if
request
.
client_address
[
0
]
not
in
authorized_origin
or
\
if
request
.
client_address
[
0
]
not
in
authorized_origin
or
\
x_forwarded_for
and
x_forwarded_for
not
in
authorized_origin
:
x_forwarded_for
and
x_forwarded_for
not
in
authorized_origin
:
return
request
.
send_error
(
http
lib
.
FORBIDDEN
)
return
request
.
send_error
(
http
.
client
.
FORBIDDEN
)
key
=
m
.
getcallargs
(
**
kw
).
get
(
'cn'
)
key
=
m
.
getcallargs
(
**
kw
).
get
(
'cn'
)
if
key
:
if
key
:
h
=
base64
.
b64decode
(
request
.
headers
[
HMAC_HEADER
])
h
=
base64
.
b64decode
(
request
.
headers
[
HMAC_HEADER
])
with
self
.
lock
:
with
self
.
lock
:
session
=
self
.
sessions
[
key
]
session
=
self
.
sessions
[
key
]
for
key
,
protocol
in
session
:
for
key
,
protocol
in
session
:
if
h
==
hmac
.
HMAC
(
key
,
request
.
path
,
hashlib
.
sha1
).
digest
():
if
h
==
hmac
.
HMAC
(
key
,
request
.
path
.
encode
(),
hashlib
.
sha1
).
digest
():
break
break
else
:
else
:
raise
Exception
(
"Wrong HMAC"
)
raise
Exception
(
"Wrong HMAC"
)
...
@@ -313,29 +327,31 @@ class RegistryServer(object):
...
@@ -313,29 +327,31 @@ class RegistryServer(object):
request
.
headers
.
get
(
"host"
))
request
.
headers
.
get
(
"host"
))
try
:
try
:
result
=
m
(
**
kw
)
result
=
m
(
**
kw
)
except
HTTPError
,
e
:
except
HTTPError
as
e
:
return
request
.
send_error
(
*
e
.
args
)
return
request
.
send_error
(
*
e
.
args
)
except
:
except
:
logging
.
warning
(
request
.
requestline
,
exc_info
=
1
)
logging
.
warning
(
request
.
requestline
,
exc_info
=
True
)
return
request
.
send_error
(
http
lib
.
INTERNAL_SERVER_ERROR
)
return
request
.
send_error
(
http
.
client
.
INTERNAL_SERVER_ERROR
)
if
result
:
if
result
:
request
.
send_response
(
httplib
.
OK
)
if
type
(
result
)
is
str
:
result
=
result
.
encode
(
"utf-8"
)
request
.
send_response
(
http
.
client
.
OK
)
request
.
send_header
(
"Content-Length"
,
str
(
len
(
result
)))
request
.
send_header
(
"Content-Length"
,
str
(
len
(
result
)))
else
:
else
:
request
.
send_response
(
http
lib
.
NO_CONTENT
)
request
.
send_response
(
http
.
client
.
NO_CONTENT
)
if
key
:
if
key
:
request
.
send_header
(
HMAC_HEADER
,
base64
.
b64encode
(
request
.
send_header
(
HMAC_HEADER
,
base64
.
b64encode
(
hmac
.
HMAC
(
key
,
result
,
hashlib
.
sha1
).
digest
()))
hmac
.
HMAC
(
key
,
result
,
hashlib
.
sha1
).
digest
())
.
decode
(
"ascii"
)
)
request
.
end_headers
()
request
.
end_headers
()
if
result
:
if
result
:
request
.
wfile
.
write
(
result
)
request
.
wfile
.
write
(
result
)
def
getPeerProtocol
(
self
,
cn
)
:
def
getPeerProtocol
(
self
,
cn
:
str
)
->
int
:
session
,
=
self
.
sessions
[
cn
]
session
,
=
self
.
sessions
[
cn
]
return
session
[
1
]
return
session
[
1
]
@
rpc
@
rpc
def
hello
(
self
,
client_prefix
,
protocol
=
'1'
)
:
def
hello
(
self
,
client_prefix
:
str
,
protocol
=
'1'
)
->
bytes
:
with
self
.
lock
:
with
self
.
lock
:
cert
=
self
.
getCert
(
client_prefix
)
cert
=
self
.
getCert
(
client_prefix
)
key
=
utils
.
newHmacSecret
()
key
=
utils
.
newHmacSecret
()
...
@@ -345,29 +361,32 @@ class RegistryServer(object):
...
@@ -345,29 +361,32 @@ class RegistryServer(object):
assert
len
(
key
)
==
len
(
sign
)
assert
len
(
key
)
==
len
(
sign
)
return
key
+
sign
return
key
+
sign
def
getCert
(
self
,
client_prefix
)
:
def
getCert
(
self
,
client_prefix
:
str
)
->
bytes
:
assert
self
.
lock
.
locked
()
assert
self
.
lock
.
locked
()
return
self
.
db
.
execute
(
"SELECT cert FROM cert"
cert
=
self
.
db
.
execute
(
"SELECT cert FROM cert"
" WHERE prefix=? AND cert IS NOT NULL"
,
" WHERE prefix=? AND cert IS NOT NULL"
,
(
client_prefix
,)).
next
()[
0
]
(
client_prefix
,)).
fetchone
()
assert
cert
,
(
f"No cert result for prefix '
{
client_prefix
}
';"
f" this should not happen, DB is inconsistent"
)
return
cert
[
0
]
@
rpc_private
@
rpc_private
def
isToken
(
self
,
token
):
def
isToken
(
self
,
token
:
str
):
with
self
.
lock
:
with
self
.
lock
:
if
self
.
db
.
execute
(
"SELECT 1 FROM token WHERE token = ?"
,
if
self
.
db
.
execute
(
"SELECT 1 FROM token WHERE token = ?"
,
(
token
,)).
fetchone
():
(
token
,)).
fetchone
():
return
"1"
return
b
"1"
@
rpc_private
@
rpc_private
def
deleteToken
(
self
,
token
):
def
deleteToken
(
self
,
token
:
str
):
with
self
.
lock
:
with
self
.
lock
:
self
.
db
.
execute
(
"DELETE FROM token WHERE token = ?"
,
(
token
,))
self
.
db
.
execute
(
"DELETE FROM token WHERE token = ?"
,
(
token
,))
@
rpc_private
@
rpc_private
def
addToken
(
self
,
email
,
token
)
:
def
addToken
(
self
,
email
:
str
,
token
:
str
|
None
)
->
str
:
prefix_len
=
self
.
config
.
prefix_length
prefix_len
=
self
.
config
.
prefix_length
if
not
prefix_len
:
if
not
prefix_len
:
raise
HTTPError
(
http
lib
.
FORBIDDEN
)
raise
HTTPError
(
http
.
client
.
FORBIDDEN
)
request
=
token
is
None
request
=
token
is
None
with
self
.
lock
:
with
self
.
lock
:
while
True
:
while
True
:
...
@@ -381,7 +400,7 @@ class RegistryServer(object):
...
@@ -381,7 +400,7 @@ class RegistryServer(object):
break
break
except
sqlite3
.
IntegrityError
:
except
sqlite3
.
IntegrityError
:
if
not
request
:
if
not
request
:
raise
HTTPError
(
http
lib
.
CONFLICT
)
raise
HTTPError
(
http
.
client
.
CONFLICT
)
self
.
timeout
=
1
self
.
timeout
=
1
if
request
:
if
request
:
return
token
return
token
...
@@ -389,7 +408,7 @@ class RegistryServer(object):
...
@@ -389,7 +408,7 @@ class RegistryServer(object):
@
rpc
@
rpc
def
requestToken
(
self
,
email
):
def
requestToken
(
self
,
email
):
if
not
self
.
config
.
mailhost
:
if
not
self
.
config
.
mailhost
:
raise
HTTPError
(
http
lib
.
FORBIDDEN
)
raise
HTTPError
(
http
.
client
.
FORBIDDEN
)
token
=
self
.
addToken
(
email
,
None
)
token
=
self
.
addToken
(
email
,
None
)
...
@@ -418,11 +437,11 @@ class RegistryServer(object):
...
@@ -418,11 +437,11 @@ class RegistryServer(object):
s
.
quit
()
s
.
quit
()
def
getCommunity
(
self
,
country
,
continent
):
def
getCommunity
(
self
,
country
,
continent
):
for
prefix
,
location_list
in
self
.
community_map
.
ite
rite
ms
():
for
prefix
,
location_list
in
self
.
community_map
.
items
():
if
country
in
location_list
:
if
country
in
location_list
:
return
prefix
return
prefix
default
=
''
default
=
''
for
prefix
,
location_list
in
self
.
community_map
.
ite
rite
ms
():
for
prefix
,
location_list
in
self
.
community_map
.
items
():
if
continent
in
location_list
:
if
continent
in
location_list
:
return
prefix
return
prefix
if
'*'
in
location_list
:
if
'*'
in
location_list
:
...
@@ -430,13 +449,14 @@ class RegistryServer(object):
...
@@ -430,13 +449,14 @@ class RegistryServer(object):
return
default
return
default
def
mergePrefixes
(
self
):
def
mergePrefixes
(
self
):
logging
.
debug
(
"Merging prefixes"
)
q
=
self
.
db
.
execute
q
=
self
.
db
.
execute
prev_prefix
=
None
prev_prefix
=
None
max_len
=
128
,
max_len
=
128
,
while
True
:
while
True
:
max_len
=
q
(
"SELECT max(length(prefix)) FROM cert"
max_len
=
q
(
"SELECT max(length(prefix)) FROM cert"
" WHERE cert is null AND length(prefix) < ?"
,
" WHERE cert is null AND length(prefix) < ?"
,
max_len
).
next
()
max_len
).
fetchone
()
if
not
max_len
[
0
]:
if
not
max_len
[
0
]:
break
break
for
prefix
,
in
q
(
"SELECT prefix FROM cert"
for
prefix
,
in
q
(
"SELECT prefix FROM cert"
...
@@ -452,6 +472,7 @@ class RegistryServer(object):
...
@@ -452,6 +472,7 @@ class RegistryServer(object):
prev_prefix
=
prefix
prev_prefix
=
prefix
def
newPrefix
(
self
,
prefix_len
,
community
):
def
newPrefix
(
self
,
prefix_len
,
community
):
logging
.
info
(
"Allocating /%u prefix for %s"
,
prefix_len
,
community
)
community_len
=
len
(
community
)
community_len
=
len
(
community
)
prefix_len
+=
community_len
prefix_len
+=
community_len
max_len
=
128
-
len
(
self
.
network
)
max_len
=
128
-
len
(
self
.
network
)
...
@@ -460,25 +481,25 @@ class RegistryServer(object):
...
@@ -460,25 +481,25 @@ class RegistryServer(object):
while
True
:
while
True
:
try
:
try
:
# Find longest free prefix whithin community.
# Find longest free prefix whithin community.
prefix
,
=
q
(
prefix
,
=
next
(
q
(
"SELECT prefix FROM cert"
"SELECT prefix FROM cert"
" WHERE prefix LIKE ?"
" WHERE prefix LIKE ?"
" AND length(prefix) <= ? AND cert is null"
" AND length(prefix) <= ? AND cert is null"
" ORDER BY length(prefix) DESC"
,
" ORDER BY length(prefix) DESC"
,
(
community
+
'%'
,
prefix_len
))
.
next
(
)
(
community
+
'%'
,
prefix_len
)))
except
StopIteration
:
except
StopIteration
:
# Community not yet allocated?
# Community not yet allocated?
# There should be exactly 1 row whose
# There should be exactly 1 row whose
# prefix is the beginning of community.
# prefix is the beginning of community.
prefix
,
x
=
q
(
"SELECT prefix, cert FROM cert"
prefix
,
x
=
next
(
q
(
"SELECT prefix, cert FROM cert"
" WHERE substr(?,1,length(prefix)) = prefix"
,
" WHERE substr(?,1,length(prefix)) = prefix"
,
(
community
,))
.
next
(
)
(
community
,)))
if
x
is
not
None
:
if
x
is
not
None
:
logging
.
error
(
'No more free /%u prefix available'
,
logging
.
error
(
'No more free /%u prefix available'
,
prefix_len
)
prefix_len
)
raise
raise
# Split the tree until prefix has wanted length.
# Split the tree until prefix has wanted length.
for
x
in
x
range
(
len
(
prefix
),
prefix_len
):
for
x
in
range
(
len
(
prefix
),
prefix_len
):
# Prefix starts with community, then we complete with 0.
# Prefix starts with community, then we complete with 0.
x
=
community
[
x
]
if
x
<
community_len
else
'0'
x
=
community
[
x
]
if
x
<
community_len
else
'0'
q
(
"UPDATE cert SET prefix = ? WHERE prefix = ?"
,
q
(
"UPDATE cert SET prefix = ? WHERE prefix = ?"
,
...
@@ -490,17 +511,19 @@ class RegistryServer(object):
...
@@ -490,17 +511,19 @@ class RegistryServer(object):
q
(
"UPDATE cert SET cert = 'reserved' WHERE prefix = ?"
,
(
prefix
,))
q
(
"UPDATE cert SET cert = 'reserved' WHERE prefix = ?"
,
(
prefix
,))
@
rpc
@
rpc
def
requestCertificate
(
self
,
token
,
req
,
location
=
''
,
ip
=
''
):
def
requestCertificate
(
self
,
token
:
str
|
None
,
req
:
bytes
,
location
:
str
=
''
,
ip
:
str
=
''
):
logging
.
debug
(
"Requesting certificate with token %s"
,
token
)
req
=
crypto
.
load_certificate_request
(
crypto
.
FILETYPE_PEM
,
req
)
req
=
crypto
.
load_certificate_request
(
crypto
.
FILETYPE_PEM
,
req
)
with
self
.
lock
:
with
self
.
lock
:
with
self
.
db
:
with
self
.
db
:
if
token
:
if
token
:
if
not
self
.
config
.
prefix_length
:
if
not
self
.
config
.
prefix_length
:
raise
HTTPError
(
http
lib
.
FORBIDDEN
)
raise
HTTPError
(
http
.
client
.
FORBIDDEN
)
try
:
try
:
token
,
email
,
prefix_len
,
_
=
self
.
db
.
execute
(
token
,
email
,
prefix_len
,
_
=
next
(
self
.
db
.
execute
(
"SELECT * FROM token WHERE token = ?"
,
"SELECT * FROM token WHERE token = ?"
,
(
token
,))
.
next
(
)
(
token
,)))
except
StopIteration
:
except
StopIteration
:
return
return
self
.
db
.
execute
(
"DELETE FROM token WHERE token = ?"
,
self
.
db
.
execute
(
"DELETE FROM token WHERE token = ?"
,
...
@@ -508,7 +531,7 @@ class RegistryServer(object):
...
@@ -508,7 +531,7 @@ class RegistryServer(object):
else
:
else
:
prefix_len
=
self
.
config
.
anonymous_prefix_length
prefix_len
=
self
.
config
.
anonymous_prefix_length
if
not
prefix_len
:
if
not
prefix_len
:
raise
HTTPError
(
http
lib
.
FORBIDDEN
)
raise
HTTPError
(
http
.
client
.
FORBIDDEN
)
email
=
None
email
=
None
country
,
continent
=
'*'
,
'*'
country
,
continent
=
'*'
,
'*'
if
self
.
geoip_db
:
if
self
.
geoip_db
:
...
@@ -563,7 +586,7 @@ class RegistryServer(object):
...
@@ -563,7 +586,7 @@ class RegistryServer(object):
return
cert
return
cert
@
rpc
@
rpc
def
renewCertificate
(
self
,
cn
)
:
def
renewCertificate
(
self
,
cn
:
str
)
->
bytes
:
with
self
.
lock
:
with
self
.
lock
:
with
self
.
db
as
db
:
with
self
.
db
as
db
:
pem
=
self
.
getCert
(
cn
)
pem
=
self
.
getCert
(
cn
)
...
@@ -579,26 +602,28 @@ class RegistryServer(object):
...
@@ -579,26 +602,28 @@ class RegistryServer(object):
cert
.
get_subject
(),
cert
.
get_pubkey
(),
not_after
)
cert
.
get_subject
(),
cert
.
get_pubkey
(),
not_after
)
@
rpc
@
rpc
def
getCa
(
self
):
def
getCa
(
self
)
->
bytes
:
return
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
.
ca
)
return
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
.
ca
)
@
rpc
@
rpc
def
getDh
(
self
,
cn
)
:
def
getDh
(
self
,
cn
:
str
)
->
bytes
:
with
open
(
self
.
config
.
dh
)
as
f
:
with
open
(
self
.
config
.
dh
,
"rb"
)
as
f
:
return
f
.
read
()
return
f
.
read
()
@
rpc
@
rpc
def
getNetworkConfig
(
self
,
cn
)
:
def
getNetworkConfig
(
self
,
cn
:
str
)
->
bytes
:
with
self
.
lock
:
with
self
.
lock
:
cert
=
self
.
getCert
(
cn
)
cert
=
self
.
getCert
(
cn
)
config
=
self
.
network_config
.
copy
()
config
=
self
.
network_config
.
copy
()
hmac
=
[
self
.
getConfig
(
k
,
None
)
for
k
in
BABEL_HMAC
]
hmac
=
[
self
.
getConfig
(
k
,
None
)
for
k
in
BABEL_HMAC
]
for
i
,
v
in
enumerate
(
v
for
v
in
hmac
if
v
is
not
None
):
for
i
,
v
in
enumerate
(
v
for
v
in
hmac
if
v
is
not
None
):
config
[(
'babel_hmac_sign'
,
'babel_hmac_accept'
)[
i
]]
=
\
config
[(
'babel_hmac_sign'
,
'babel_hmac_accept'
)[
i
]]
=
\
v
and
x509
.
encrypt
(
cert
,
v
).
encode
(
'base64'
)
v
and
base64
.
b64encode
(
x509
.
encrypt
(
cert
,
v
)).
decode
(
)
return
zlib
.
compress
(
json
.
dumps
(
config
))
return
zlib
.
compress
(
json
.
dumps
(
config
)
.
encode
(
"utf-8"
)
)
def
_queryAddress
(
self
,
peer
):
def
_queryAddress
(
self
,
peer
:
str
)
->
str
:
logging
.
info
(
"Querying address for %s/%s %r"
,
int
(
peer
,
2
),
len
(
peer
),
peer
)
self
.
sendto
(
peer
,
1
)
self
.
sendto
(
peer
,
1
)
s
=
self
.
sock
,
s
=
self
.
sock
,
timeout
=
3
timeout
=
3
...
@@ -606,6 +631,7 @@ class RegistryServer(object):
...
@@ -606,6 +631,7 @@ class RegistryServer(object):
# Loop because there may be answers from previous requests.
# Loop because there may be answers from previous requests.
while
select
.
select
(
s
,
(),
(),
timeout
)[
0
]:
while
select
.
select
(
s
,
(),
(),
timeout
)[
0
]:
prefix
,
msg
=
self
.
recv
(
1
)
prefix
,
msg
=
self
.
recv
(
1
)
logging
.
info
(
"* received: %r - %r"
,
prefix
,
msg
)
if
prefix
==
peer
:
if
prefix
==
peer
:
return
msg
return
msg
timeout
=
max
(
0
,
end
-
time
.
time
())
timeout
=
max
(
0
,
end
-
time
.
time
())
...
@@ -613,23 +639,25 @@ class RegistryServer(object):
...
@@ -613,23 +639,25 @@ class RegistryServer(object):
int
(
peer
,
2
),
len
(
peer
))
int
(
peer
,
2
),
len
(
peer
))
@
rpc
@
rpc
def
getCountry
(
self
,
cn
,
address
)
:
def
getCountry
(
self
,
cn
:
str
,
address
:
str
)
->
str
|
None
:
country
=
self
.
_geoiplookup
(
address
)[
0
]
country
=
self
.
_geoiplookup
(
address
)[
0
]
return
None
if
country
==
'*'
else
country
return
None
if
country
==
'*'
else
country
@
rpc
@
rpc
def
getBootstrapPeer
(
self
,
cn
):
def
getBootstrapPeer
(
self
,
cn
:
str
)
->
bytes
|
None
:
logging
.
info
(
"Answering bootstrap peer for %s"
,
cn
)
with
self
.
peers_lock
:
with
self
.
peers_lock
:
age
,
peers
=
self
.
peers
age
,
peers
=
self
.
peers
if
age
<
time
.
time
()
or
not
peers
:
if
age
<
time
.
time
()
or
not
peers
:
self
.
request_dump
()
self
.
request_dump
()
peers
=
[
prefix
peers
=
[
prefix
for
neigh_routes
in
self
.
ctl
.
neighbours
.
iter
values
()
for
neigh_routes
in
self
.
ctl
.
neighbours
.
values
()
for
prefix
in
neigh_routes
[
1
]
for
prefix
in
neigh_routes
[
1
]
if
prefix
]
if
prefix
]
peers
.
append
(
self
.
prefix
)
peers
.
append
(
self
.
prefix
)
random
.
shuffle
(
peers
)
random
.
shuffle
(
peers
)
self
.
peers
=
time
.
time
()
+
60
,
peers
self
.
peers
=
time
.
time
()
+
60
,
peers
logging
.
debug
(
"peers: %r"
,
peers
)
peer
=
peers
.
pop
()
peer
=
peers
.
pop
()
if
peer
==
cn
:
if
peer
==
cn
:
# Very unlikely (e.g. peer restarted with empty cache),
# Very unlikely (e.g. peer restarted with empty cache),
...
@@ -639,6 +667,7 @@ class RegistryServer(object):
...
@@ -639,6 +667,7 @@ class RegistryServer(object):
with
self
.
lock
:
with
self
.
lock
:
msg
=
self
.
_queryAddress
(
peer
)
msg
=
self
.
_queryAddress
(
peer
)
if
msg
is
None
:
if
msg
is
None
:
logging
.
info
(
"No address for %s, returning None"
,
peer
)
return
return
# Remove country for old nodes
# Remove country for old nodes
if
self
.
getPeerProtocol
(
cn
)
<
7
:
if
self
.
getPeerProtocol
(
cn
)
<
7
:
...
@@ -647,10 +676,10 @@ class RegistryServer(object):
...
@@ -647,10 +676,10 @@ class RegistryServer(object):
cert
=
self
.
getCert
(
cn
)
cert
=
self
.
getCert
(
cn
)
msg
=
"%s %s"
%
(
peer
,
msg
)
msg
=
"%s %s"
%
(
peer
,
msg
)
logging
.
info
(
"Sending bootstrap peer: %s"
,
msg
)
logging
.
info
(
"Sending bootstrap peer: %s"
,
msg
)
return
x509
.
encrypt
(
cert
,
msg
)
return
x509
.
encrypt
(
cert
,
msg
.
encode
()
)
@
rpc_private
@
rpc_private
def
revoke
(
self
,
cn_or_serial
):
def
revoke
(
self
,
cn_or_serial
:
int
|
str
):
with
self
.
lock
,
self
.
db
:
with
self
.
lock
,
self
.
db
:
q
=
self
.
db
.
execute
q
=
self
.
db
.
execute
try
:
try
:
...
@@ -671,12 +700,12 @@ class RegistryServer(object):
...
@@ -671,12 +700,12 @@ class RegistryServer(object):
q
(
"INSERT INTO crl VALUES (?,?)"
,
(
serial
,
not_after
))
q
(
"INSERT INTO crl VALUES (?,?)"
,
(
serial
,
not_after
))
self
.
updateNetworkConfig
()
self
.
updateNetworkConfig
()
def
newHMAC
(
self
,
i
,
key
=
None
):
def
newHMAC
(
self
,
i
:
int
,
key
:
bytes
=
None
):
if
key
is
None
:
if
key
is
None
:
key
=
buffer
(
os
.
urandom
(
16
)
)
key
=
os
.
urandom
(
16
)
self
.
setConfig
(
BABEL_HMAC
[
i
],
key
)
self
.
setConfig
(
BABEL_HMAC
[
i
],
key
)
def
delHMAC
(
self
,
i
):
def
delHMAC
(
self
,
i
:
int
):
self
.
db
.
execute
(
"DELETE FROM config WHERE name=?"
,
(
BABEL_HMAC
[
i
],))
self
.
db
.
execute
(
"DELETE FROM config WHERE name=?"
,
(
BABEL_HMAC
[
i
],))
@
rpc_private
@
rpc_private
...
@@ -696,25 +725,26 @@ class RegistryServer(object):
...
@@ -696,25 +725,26 @@ class RegistryServer(object):
else
:
else
:
# Initialization of HMAC on the network
# Initialization of HMAC on the network
self
.
newHMAC
(
1
)
self
.
newHMAC
(
1
)
self
.
newHMAC
(
2
,
''
)
self
.
newHMAC
(
2
,
b
''
)
self
.
increaseVersion
()
self
.
increaseVersion
()
self
.
setConfig
(
'version'
,
buffer
(
self
.
version
)
)
self
.
setConfig
(
'version'
,
self
.
version
)
self
.
network_config
[
'version'
]
=
self
.
version
.
encode
(
'base64'
)
self
.
network_config
[
'version'
]
=
base64
.
b64encode
(
self
.
version
)
self
.
sendto
(
self
.
prefix
,
0
)
self
.
sendto
(
self
.
prefix
,
0
)
@
rpc_private
@
rpc_private
def
getNodePrefix
(
self
,
email
)
:
def
getNodePrefix
(
self
,
email
:
str
)
->
str
|
None
:
with
self
.
lock
,
self
.
db
:
with
self
.
lock
,
self
.
db
:
try
:
try
:
cert
,
=
self
.
db
.
execute
(
"SELECT cert FROM cert WHERE email = ?"
,
cert
,
=
next
(
(
email
,)).
next
()
self
.
db
.
execute
(
"SELECT cert FROM cert WHERE email = ?"
,
(
email
,)))
except
StopIteration
:
except
StopIteration
:
return
return
certificate
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
cert
)
certificate
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
cert
)
return
x509
.
subnetFromCert
(
certificate
)
return
x509
.
subnetFromCert
(
certificate
)
@
rpc_private
@
rpc_private
def
getIPv6Address
(
self
,
email
)
:
def
getIPv6Address
(
self
,
email
:
str
)
->
str
:
cn
=
self
.
getNodePrefix
(
email
)
cn
=
self
.
getNodePrefix
(
email
)
if
cn
:
if
cn
:
return
utils
.
ipFromBin
(
return
utils
.
ipFromBin
(
...
@@ -722,13 +752,13 @@ class RegistryServer(object):
...
@@ -722,13 +752,13 @@ class RegistryServer(object):
+
utils
.
binFromSubnet
(
cn
))
+
utils
.
binFromSubnet
(
cn
))
@
rpc_private
@
rpc_private
def
getIPv4Information
(
self
,
email
)
:
def
getIPv4Information
(
self
,
email
:
str
)
->
str
|
None
:
peer
=
self
.
getNodePrefix
(
email
)
peer
=
self
.
getNodePrefix
(
email
)
if
peer
:
if
peer
:
peer
=
utils
.
binFromSubnet
(
peer
)
peer
=
utils
.
binFromSubnet
(
peer
)
with
self
.
peers_lock
:
with
self
.
peers_lock
:
self
.
request_dump
()
self
.
request_dump
()
for
neigh_routes
in
self
.
ctl
.
neighbours
.
iter
values
():
for
neigh_routes
in
self
.
ctl
.
neighbours
.
values
():
for
prefix
in
neigh_routes
[
1
]:
for
prefix
in
neigh_routes
[
1
]:
if
prefix
==
peer
:
if
prefix
==
peer
:
break
break
...
@@ -736,16 +766,16 @@ class RegistryServer(object):
...
@@ -736,16 +766,16 @@ class RegistryServer(object):
return
return
logging
.
info
(
"%s %s"
,
email
,
peer
)
logging
.
info
(
"%s %s"
,
email
,
peer
)
with
self
.
lock
:
with
self
.
lock
:
msg
=
self
.
_queryAddress
(
peer
)
msg
=
self
.
_queryAddress
(
peer
)
.
decode
()
if
msg
:
if
msg
:
return
msg
.
split
(
','
)[
0
]
return
msg
.
split
(
','
)[
0
]
@
rpc_private
@
rpc_private
def
versions
(
self
):
def
versions
(
self
)
->
str
:
with
self
.
peers_lock
:
with
self
.
peers_lock
:
self
.
request_dump
()
self
.
request_dump
()
peers
=
{
prefix
peers
=
{
prefix
for
neigh_routes
in
self
.
ctl
.
neighbours
.
iter
values
()
for
neigh_routes
in
self
.
ctl
.
neighbours
.
values
()
for
prefix
in
neigh_routes
[
1
]
for
prefix
in
neigh_routes
[
1
]
if
prefix
}
if
prefix
}
peers
.
add
(
self
.
prefix
)
peers
.
add
(
self
.
prefix
)
...
@@ -767,7 +797,8 @@ class RegistryServer(object):
...
@@ -767,7 +797,8 @@ class RegistryServer(object):
return
json
.
dumps
(
peer_dict
)
return
json
.
dumps
(
peer_dict
)
@
rpc_private
@
rpc_private
def
topology
(
self
):
def
topology
(
self
)
->
str
:
logging
.
debug
(
"Computing topology"
)
p
=
lambda
p
:
'%s/%s'
%
(
int
(
p
,
2
),
len
(
p
))
p
=
lambda
p
:
'%s/%s'
%
(
int
(
p
,
2
),
len
(
p
))
peers
=
deque
((
p
(
self
.
prefix
),))
peers
=
deque
((
p
(
self
.
prefix
),))
graph
=
defaultdict
(
set
)
graph
=
defaultdict
(
set
)
...
@@ -777,7 +808,8 @@ class RegistryServer(object):
...
@@ -777,7 +808,8 @@ class RegistryServer(object):
r
,
w
,
_
=
select
.
select
(
s
,
s
if
peers
else
(),
(),
3
)
r
,
w
,
_
=
select
.
select
(
s
,
s
if
peers
else
(),
(),
3
)
if
r
:
if
r
:
prefix
,
x
=
self
.
recv
(
5
)
prefix
,
x
=
self
.
recv
(
5
)
if
prefix
and
x
:
logging
.
debug
(
"Received %s %s"
,
prefix
,
x
)
if
prefix
:
prefix
=
p
(
prefix
)
prefix
=
p
(
prefix
)
x
=
x
.
split
()
x
=
x
.
split
()
try
:
try
:
...
@@ -791,35 +823,46 @@ class RegistryServer(object):
...
@@ -791,35 +823,46 @@ class RegistryServer(object):
graph
[
x
].
add
(
prefix
)
graph
[
x
].
add
(
prefix
)
graph
[
''
].
add
(
prefix
)
graph
[
''
].
add
(
prefix
)
if
w
:
if
w
:
self
.
sendto
(
utils
.
binFromSubnet
(
peers
.
popleft
()),
5
)
first
=
peers
.
popleft
()
logging
.
debug
(
"Sending %s"
,
first
)
self
.
sendto
(
utils
.
binFromSubnet
(
first
),
5
)
elif
not
r
:
elif
not
r
:
logging
.
debug
(
"No more sockets, stopping"
)
break
break
return
json
.
dumps
({
k
:
list
(
v
)
for
k
,
v
in
graph
.
iteritems
()})
return
json
.
dumps
({
k
:
list
(
v
)
for
k
,
v
in
graph
.
items
()})
class
RegistryClient
:
"""
Client for the re6st registry.
class
RegistryClient
(
object
):
Method calls are forwarded to the registry server.
String results are always returned as bytes.
"""
_hmac
=
None
_hmac
=
None
user_agent
=
"re6stnet/%s, %s"
%
(
version
.
version
,
platform
.
platform
())
user_agent
=
"re6stnet/%s, %s"
%
(
version
.
version
,
platform
.
platform
())
def
__init__
(
self
,
url
,
c
ert
=
None
,
auto_close
=
True
):
def
__init__
(
self
,
url
:
str
,
cert
:
x509
.
C
ert
=
None
,
auto_close
=
True
):
self
.
cert
=
cert
self
.
cert
=
cert
self
.
auto_close
=
auto_close
self
.
auto_close
=
auto_close
scheme
,
host
=
splittype
(
url
)
url_parsed
=
urlparse
(
url
)
host
,
path
=
splithost
(
host
)
scheme
=
url_parsed
.
scheme
self
.
_conn
=
dict
(
http
=
httplib
.
HTTPConnection
,
host
=
url_parsed
.
netloc
https
=
httplib
.
HTTPSConnection
,
path
=
url_parsed
.
path
self
.
_conn
=
dict
(
http
=
http
.
client
.
HTTPConnection
,
https
=
http
.
client
.
HTTPSConnection
,
)[
scheme
](
unquote
(
host
),
timeout
=
60
)
)[
scheme
](
unquote
(
host
),
timeout
=
60
)
self
.
_path
=
path
.
rstrip
(
'/'
)
self
.
_path
=
path
.
rstrip
(
'/'
)
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
:
str
):
getcallargs
=
getattr
(
RegistryServer
,
name
).
getcallargs
getcallargs
=
getattr
(
RegistryServer
,
name
).
getcallargs
def
rpc
(
*
args
,
**
kw
):
def
rpc
(
*
args
,
**
kw
)
->
bytes
:
kw
=
getcallargs
(
*
args
,
**
kw
)
kw
=
getcallargs
(
*
args
,
**
kw
)
query
=
'/'
+
name
query
=
'/'
+
name
if
kw
:
if
kw
:
if
any
(
type
(
v
)
is
not
str
for
v
in
kw
.
iter
values
()):
if
any
(
not
isinstance
(
v
,
(
str
,
bytes
))
for
v
in
kw
.
values
()):
raise
TypeError
raise
TypeError
(
kw
)
query
+=
'?'
+
urlencode
(
kw
)
query
+=
'?'
+
urlencode
(
kw
)
url
=
self
.
_path
+
query
url
=
self
.
_path
+
query
client_prefix
=
kw
.
get
(
'cn'
)
client_prefix
=
kw
.
get
(
'cn'
)
...
@@ -834,7 +877,8 @@ class RegistryClient(object):
...
@@ -834,7 +877,8 @@ class RegistryClient(object):
n
=
len
(
h
)
//
2
n
=
len
(
h
)
//
2
self
.
cert
.
verify
(
h
[
n
:],
h
[:
n
])
self
.
cert
.
verify
(
h
[
n
:],
h
[:
n
])
key
=
self
.
cert
.
decrypt
(
h
[:
n
])
key
=
self
.
cert
.
decrypt
(
h
[:
n
])
h
=
hmac
.
HMAC
(
key
,
query
,
hashlib
.
sha1
).
digest
()
h
=
(
hmac
.
HMAC
(
key
,
query
.
encode
(),
hashlib
.
sha1
)
.
digest
())
key
=
hashlib
.
sha1
(
key
).
digest
()
key
=
hashlib
.
sha1
(
key
).
digest
()
self
.
_hmac
=
hashlib
.
sha1
(
key
).
digest
()
self
.
_hmac
=
hashlib
.
sha1
(
key
).
digest
()
else
:
else
:
...
@@ -846,14 +890,15 @@ class RegistryClient(object):
...
@@ -846,14 +890,15 @@ class RegistryClient(object):
self
.
_conn
.
endheaders
()
self
.
_conn
.
endheaders
()
response
=
self
.
_conn
.
getresponse
()
response
=
self
.
_conn
.
getresponse
()
body
=
response
.
read
()
body
=
response
.
read
()
if
response
.
status
in
(
httplib
.
OK
,
httplib
.
NO_CONTENT
):
if
response
.
status
in
(
http
.
client
.
OK
,
http
.
client
.
NO_CONTENT
):
if
(
not
client_prefix
or
if
(
not
client_prefix
or
hmac
.
HMAC
(
key
,
body
,
hashlib
.
sha1
).
digest
()
==
hmac
.
HMAC
(
key
,
body
,
hashlib
.
sha1
).
digest
()
==
base64
.
b64decode
(
response
.
msg
[
HMAC_HEADER
])):
base64
.
b64decode
(
response
.
msg
[
HMAC_HEADER
])):
if
self
.
auto_close
and
name
!=
'hello'
:
if
self
.
auto_close
and
name
!=
'hello'
:
self
.
_conn
.
close
()
self
.
_conn
.
close
()
return
body
return
body
elif
response
.
status
==
http
lib
.
FORBIDDEN
:
elif
response
.
status
==
http
.
client
.
FORBIDDEN
:
# XXX: We should improve error handling, while making
# XXX: We should improve error handling, while making
# sure re6st nodes don't crash on temporary errors.
# sure re6st nodes don't crash on temporary errors.
# This is currently good enough for re6st-conf, to
# This is currently good enough for re6st-conf, to
...
@@ -864,7 +909,7 @@ class RegistryClient(object):
...
@@ -864,7 +909,7 @@ class RegistryClient(object):
except
HTTPError
:
except
HTTPError
:
raise
raise
except
Exception
:
except
Exception
:
logging
.
info
(
url
,
exc_info
=
1
)
logging
.
info
(
url
,
exc_info
=
True
)
else
:
else
:
logging
.
info
(
'%s
\
n
Unexpected response %s %s'
,
logging
.
info
(
'%s
\
n
Unexpected response %s %s'
,
url
,
response
.
status
,
response
.
reason
)
url
,
response
.
status
,
response
.
reason
)
...
...
re6st/tests/__init__.py
View file @
e80ca878
from
pathlib
2
import
Path
from
pathlib
import
Path
DEMO_PATH
=
Path
(
__file__
).
resolve
().
parent
.
parent
.
parent
/
"demo"
DEMO_PATH
=
Path
(
__file__
).
resolve
().
parent
.
parent
.
parent
/
"demo"
re6st/tests/test_end2end/test_registry_client.py
View file @
e80ca878
...
@@ -15,7 +15,7 @@ from re6st.tests import DEMO_PATH
...
@@ -15,7 +15,7 @@ from re6st.tests import DEMO_PATH
DH_FILE
=
DEMO_PATH
/
"dh2048.pem"
DH_FILE
=
DEMO_PATH
/
"dh2048.pem"
class
DummyNode
(
object
)
:
class
DummyNode
:
"""fake node to reuse Re6stRegistry
"""fake node to reuse Re6stRegistry
error: node.Popen has destory method which not in subprocess.Popen
error: node.Popen has destory method which not in subprocess.Popen
...
@@ -29,19 +29,20 @@ class TestRegistryClientInteract(unittest.TestCase):
...
@@ -29,19 +29,20 @@ class TestRegistryClientInteract(unittest.TestCase):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
re6st_wrap
.
initial
()
# if running in net ns, set lo up
# if running in net ns, set lo up
subprocess
.
check_call
((
"ip"
,
"link"
,
"set"
,
"lo"
,
"up"
))
subprocess
.
check_call
((
"ip"
,
"link"
,
"set"
,
"lo"
,
"up"
))
def
setUp
(
self
):
def
setUp
(
self
):
re6st_wrap
.
initial
()
self
.
port
=
18080
self
.
port
=
18080
self
.
url
=
"http://localhost:{}/"
.
format
(
self
.
port
)
self
.
url
=
"http://localhost:{}/"
.
format
(
self
.
port
)
# not important, used in network_config check
# not important, used in network_config check
self
.
max_clients
=
10
self
.
max_clients
=
10
def
tearDown
(
self
):
def
tearDown
(
self
):
self
.
server
.
proc
.
terminate
()
with
self
.
server
.
proc
as
p
:
p
.
terminate
()
def
test_1_main
(
self
):
def
test_1_main
(
self
):
""" a client interact a server, no re6stnet node test basic function"""
""" a client interact a server, no re6stnet node test basic function"""
...
@@ -60,7 +61,7 @@ class TestRegistryClientInteract(unittest.TestCase):
...
@@ -60,7 +61,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# read token from db
# read token from db
db
=
sqlite3
.
connect
(
str
(
self
.
server
.
db
),
isolation_level
=
None
)
db
=
sqlite3
.
connect
(
str
(
self
.
server
.
db
),
isolation_level
=
None
)
token
=
None
token
=
None
for
_
in
x
range
(
100
):
for
_
in
range
(
100
):
time
.
sleep
(.
1
)
time
.
sleep
(.
1
)
token
=
db
.
execute
(
"SELECT token FROM token WHERE email=?"
,
token
=
db
.
execute
(
"SELECT token FROM token WHERE email=?"
,
(
email
,)).
fetchone
()
(
email
,)).
fetchone
()
...
@@ -70,7 +71,7 @@ class TestRegistryClientInteract(unittest.TestCase):
...
@@ -70,7 +71,7 @@ class TestRegistryClientInteract(unittest.TestCase):
self
.
fail
(
"Request token failed, no token in database"
)
self
.
fail
(
"Request token failed, no token in database"
)
# token: tuple[unicode,]
# token: tuple[unicode,]
token
=
str
(
token
[
0
])
token
=
str
(
token
[
0
])
self
.
assertEqual
(
client
.
isToken
(
token
),
"1"
)
self
.
assertEqual
(
client
.
isToken
(
token
),
b
"1"
)
# request ca
# request ca
ca
=
client
.
getCa
()
ca
=
client
.
getCa
()
...
@@ -78,7 +79,7 @@ class TestRegistryClientInteract(unittest.TestCase):
...
@@ -78,7 +79,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# request a cert and get cn
# request a cert and get cn
key
,
csr
=
tools
.
generate_csr
()
key
,
csr
=
tools
.
generate_csr
()
cert
=
client
.
requestCertificate
(
token
,
csr
)
cert
=
client
.
requestCertificate
(
token
,
csr
)
self
.
assertEqual
(
client
.
isToken
(
token
),
''
,
"token should be deleted"
)
self
.
assertEqual
(
client
.
isToken
(
token
),
b
''
,
"token should be deleted"
)
# creat x509.cert object
# creat x509.cert object
def
write_to_temp
(
text
):
def
write_to_temp
(
text
):
...
@@ -97,7 +98,7 @@ class TestRegistryClientInteract(unittest.TestCase):
...
@@ -97,7 +98,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# verfiy cn and prefix
# verfiy cn and prefix
prefix
=
client
.
cert
.
prefix
prefix
=
client
.
cert
.
prefix
cn
=
client
.
getNodePrefix
(
email
)
cn
=
client
.
getNodePrefix
(
email
)
.
decode
()
self
.
assertEqual
(
tools
.
prefix2cn
(
prefix
),
cn
)
self
.
assertEqual
(
tools
.
prefix2cn
(
prefix
),
cn
)
# simulate the process in cache
# simulate the process in cache
...
@@ -108,7 +109,7 @@ class TestRegistryClientInteract(unittest.TestCase):
...
@@ -108,7 +109,7 @@ class TestRegistryClientInteract(unittest.TestCase):
# no re6stnet, empty result
# no re6stnet, empty result
bootpeer
=
client
.
getBootstrapPeer
(
prefix
)
bootpeer
=
client
.
getBootstrapPeer
(
prefix
)
self
.
assertEqual
(
bootpeer
,
""
)
self
.
assertEqual
(
bootpeer
,
b
""
)
# server should not die
# server should not die
self
.
assertIsNone
(
self
.
server
.
proc
.
poll
())
self
.
assertIsNone
(
self
.
server
.
proc
.
poll
())
...
...
re6st/tests/test_network/network_build.py
View file @
e80ca878
...
@@ -3,14 +3,9 @@ import logging
...
@@ -3,14 +3,9 @@ import logging
import
nemu
import
nemu
import
time
import
time
import
weakref
import
weakref
from
subprocess
import
PIPE
from
subprocess
import
DEVNULL
,
PIPE
from
pathlib
2
import
Path
from
pathlib
import
Path
from
re6st.tests
import
DEMO_PATH
fix_file
=
DEMO_PATH
/
"fixnemu.py"
# execfile(str(fix_file)) Removed in python3
exec
(
open
(
str
(
fix_file
)).
read
())
IPTABLES
=
'iptables-nft'
IPTABLES
=
'iptables-nft'
class
ConnectableError
(
Exception
):
class
ConnectableError
(
Exception
):
...
@@ -50,7 +45,7 @@ class Node(nemu.Node):
...
@@ -50,7 +45,7 @@ class Node(nemu.Node):
if_s
.
add_v4_address
(
ip
,
prefix_len
=
prefix_len
)
if_s
.
add_v4_address
(
ip
,
prefix_len
=
prefix_len
)
return
if_s
return
if_s
class
NetManager
(
object
)
:
class
NetManager
:
"""contain all the nemu object created, so they can live more time"""
"""contain all the nemu object created, so they can live more time"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
object
=
[]
self
.
object
=
[]
...
@@ -60,10 +55,11 @@ class NetManager(object):
...
@@ -60,10 +55,11 @@ class NetManager(object):
Raise:
Raise:
AssertionError
AssertionError
"""
"""
for
reg
,
nodes
in
self
.
registries
.
ite
rite
ms
():
for
reg
,
nodes
in
self
.
registries
.
items
():
for
node
in
nodes
:
for
node
in
nodes
:
app0
=
node
.
Popen
([
"ping"
,
"-c"
,
"1"
,
reg
.
ip
],
stdout
=
PIPE
)
with
node
.
Popen
([
"ping"
,
"-c"
,
"1"
,
reg
.
ip
],
ret
=
app0
.
wait
()
stdout
=
DEVNULL
)
as
app0
:
ret
=
app0
.
wait
()
if
ret
:
if
ret
:
raise
ConnectableError
(
raise
ConnectableError
(
"network construct failed {} to {}"
.
format
(
node
.
ip
,
reg
.
ip
))
"network construct failed {} to {}"
.
format
(
node
.
ip
,
reg
.
ip
))
...
...
re6st/tests/test_network/re6st_wrap.py
View file @
e80ca878
...
@@ -6,13 +6,15 @@ import ipaddress
...
@@ -6,13 +6,15 @@ import ipaddress
import
json
import
json
import
logging
import
logging
import
re
import
re
import
shlex
import
shutil
import
shutil
import
sqlite3
import
sqlite3
import
sys
import
tempfile
import
tempfile
import
time
import
time
import
weakref
import
weakref
from
subprocess
import
PIPE
from
subprocess
import
PIPE
from
pathlib
2
import
Path
from
pathlib
import
Path
from
re6st.tests
import
tools
from
re6st.tests
import
tools
from
re6st.tests
import
DEMO_PATH
from
re6st.tests
import
DEMO_PATH
...
@@ -20,14 +22,16 @@ from re6st.tests import DEMO_PATH
...
@@ -20,14 +22,16 @@ from re6st.tests import DEMO_PATH
WORK_DIR
=
Path
(
__file__
).
parent
/
"temp_net_test"
WORK_DIR
=
Path
(
__file__
).
parent
/
"temp_net_test"
DH_FILE
=
DEMO_PATH
/
"dh2048.pem"
DH_FILE
=
DEMO_PATH
/
"dh2048.pem"
RE6STNET
=
"python -m re6st.cli.node"
PYTHON
=
shlex
.
quote
(
sys
.
executable
)
RE6ST_REGISTRY
=
"python -m re6st.cli.registry"
RE6STNET
=
PYTHON
+
" -m re6st.cli.node"
RE6ST_CONF
=
"python -m re6st.cli.conf"
RE6ST_REGISTRY
=
PYTHON
+
" -m re6st.cli.registry"
RE6ST_CONF
=
PYTHON
+
" -m re6st.cli.conf"
def
initial
():
def
initial
():
"""create the workplace"""
"""create the workplace"""
if
not
WORK_DIR
.
exists
():
if
WORK_DIR
.
exists
():
WORK_DIR
.
mkdir
()
shutil
.
rmtree
(
str
(
WORK_DIR
))
WORK_DIR
.
mkdir
()
def
ip_to_serial
(
ip6
):
def
ip_to_serial
(
ip6
):
"""convert ipv6 address to serial"""
"""convert ipv6 address to serial"""
...
@@ -36,7 +40,7 @@ def ip_to_serial(ip6):
...
@@ -36,7 +40,7 @@ def ip_to_serial(ip6):
return
int
(
ip6
,
16
)
return
int
(
ip6
,
16
)
class
Re6stRegistry
(
object
)
:
class
Re6stRegistry
:
"""class run a re6st-registry service on a namespace"""
"""class run a re6st-registry service on a namespace"""
registry_seq
=
0
registry_seq
=
0
...
@@ -72,7 +76,7 @@ class Re6stRegistry(object):
...
@@ -72,7 +76,7 @@ class Re6stRegistry(object):
self
.
run
()
self
.
run
()
# wait the servcice started
# wait the servcice started
p
=
self
.
node
.
Popen
([
'python'
,
'-c'
,
"""if 1:
p
=
self
.
node
.
Popen
([
sys
.
executable
,
'-c'
,
"""if 1:
import socket, time
import socket, time
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
while True:
while True:
...
@@ -115,7 +119,7 @@ class Re6stRegistry(object):
...
@@ -115,7 +119,7 @@ class Re6stRegistry(object):
'--client-count'
,
(
self
.
client_number
+
1
)
//
2
,
'--port'
,
self
.
port
]
'--client-count'
,
(
self
.
client_number
+
1
)
//
2
,
'--port'
,
self
.
port
]
#PY3: convert PosixPath to str, can be remove in Python 3
#PY3: convert PosixPath to str, can be remove in Python 3
cmd
=
map
(
str
,
cmd
)
cmd
=
list
(
map
(
str
,
cmd
)
)
cmd
[:
0
]
=
RE6ST_REGISTRY
.
split
()
cmd
[:
0
]
=
RE6ST_REGISTRY
.
split
()
...
@@ -131,15 +135,19 @@ class Re6stRegistry(object):
...
@@ -131,15 +135,19 @@ class Re6stRegistry(object):
if
e
.
errno
!=
errno
.
ENOENT
:
if
e
.
errno
!=
errno
.
ENOENT
:
raise
raise
def
__del__
(
self
):
def
terminate
(
self
):
try
:
try
:
logging
.
debug
(
"teminate process %s"
,
self
.
proc
.
pid
)
logging
.
debug
(
"teminate process %s"
,
self
.
proc
.
pid
)
self
.
proc
.
destroy
()
with
self
.
proc
as
p
:
p
.
destroy
()
except
:
except
:
pass
pass
def
__del__
(
self
):
self
.
terminate
()
class
Re6stNode
(
object
)
:
class
Re6stNode
:
"""class run a re6stnet service on a namespace"""
"""class run a re6stnet service on a namespace"""
node_seq
=
0
node_seq
=
0
...
@@ -210,7 +218,7 @@ class Re6stNode(object):
...
@@ -210,7 +218,7 @@ class Re6stNode(object):
# read token
# read token
db
=
sqlite3
.
connect
(
str
(
self
.
registry
.
db
),
isolation_level
=
None
)
db
=
sqlite3
.
connect
(
str
(
self
.
registry
.
db
),
isolation_level
=
None
)
token
=
None
token
=
None
for
_
in
x
range
(
100
):
for
_
in
range
(
100
):
time
.
sleep
(.
1
)
time
.
sleep
(.
1
)
token
=
db
.
execute
(
"SELECT token FROM token WHERE email=?"
,
token
=
db
.
execute
(
"SELECT token FROM token WHERE email=?"
,
(
self
.
email
,)).
fetchone
()
(
self
.
email
,)).
fetchone
()
...
@@ -223,7 +231,8 @@ class Re6stNode(object):
...
@@ -223,7 +231,8 @@ class Re6stNode(object):
out
,
_
=
p
.
communicate
(
str
(
token
[
0
]))
out
,
_
=
p
.
communicate
(
str
(
token
[
0
]))
# logging.debug("re6st-conf output: {}".format(out))
# logging.debug("re6st-conf output: {}".format(out))
# find the ipv6 subnet of node
# find the ipv6 subnet of node
self
.
ip6
=
re
.
search
(
'(?<=subnet: )[0-9:a-z]+'
,
out
).
group
(
0
)
self
.
ip6
=
re
.
search
(
'(?<=subnet: )[0-9:a-z]+'
,
out
.
decode
(
"utf-8"
)).
group
(
0
)
data
=
{
'ip6'
:
self
.
ip6
,
'hash'
:
self
.
registry
.
ident
}
data
=
{
'ip6'
:
self
.
ip6
,
'hash'
:
self
.
registry
.
ident
}
with
open
(
str
(
self
.
data_file
),
'w'
)
as
f
:
with
open
(
str
(
self
.
data_file
),
'w'
)
as
f
:
json
.
dump
(
data
,
f
)
json
.
dump
(
data
,
f
)
...
@@ -236,7 +245,7 @@ class Re6stNode(object):
...
@@ -236,7 +245,7 @@ class Re6stNode(object):
'--key'
,
self
.
key
,
'-v4'
,
'--registry'
,
self
.
registry
.
url
,
'--key'
,
self
.
key
,
'-v4'
,
'--registry'
,
self
.
registry
.
url
,
'--console'
,
self
.
console
]
'--console'
,
self
.
console
]
#PY3: same as for Re6stRegistry.run
#PY3: same as for Re6stRegistry.run
cmd
=
map
(
str
,
cmd
)
cmd
=
list
(
map
(
str
,
cmd
)
)
cmd
[:
0
]
=
RE6STNET
.
split
()
cmd
[:
0
]
=
RE6STNET
.
split
()
cmd
+=
args
cmd
+=
args
...
@@ -260,7 +269,8 @@ class Re6stNode(object):
...
@@ -260,7 +269,8 @@ class Re6stNode(object):
def
stop
(
self
):
def
stop
(
self
):
"""stop running re6stnet process"""
"""stop running re6stnet process"""
logging
.
debug
(
"%s teminate process %s"
,
self
.
name
,
self
.
proc
.
pid
)
logging
.
debug
(
"%s teminate process %s"
,
self
.
name
,
self
.
proc
.
pid
)
self
.
proc
.
destroy
()
with
self
.
proc
as
p
:
p
.
destroy
()
def
__del__
(
self
):
def
__del__
(
self
):
"""teminate process and rm temp dir"""
"""teminate process and rm temp dir"""
...
...
re6st/tests/test_network/test_net.py
View file @
e80ca878
"""contain ping-test for re6set net"""
"""contain ping-test for re6set net"""
import
os
import
os
import
sys
import
unittest
import
unittest
import
time
import
time
import
psutil
import
psutil
import
logging
import
logging
import
random
import
random
from
pathlib
2
import
Path
from
pathlib
import
Path
import
network_build
from
.
import
network_build
,
re6st_wrap
import
re6st_wrap
PING_PATH
=
str
(
Path
(
__file__
).
parent
.
resolve
()
/
"ping.py"
)
PING_PATH
=
str
(
Path
(
__file__
).
parent
.
resolve
()
/
"ping.py"
)
def
deploy_re6st
(
nm
,
recreate
=
False
):
net
=
nm
.
registries
nodes
=
[]
registries
=
[]
re6st_wrap
.
Re6stRegistry
.
registry_seq
=
0
re6st_wrap
.
Re6stNode
.
node_seq
=
0
for
registry
in
net
:
reg
=
re6st_wrap
.
Re6stRegistry
(
registry
,
"2001:db8:42::"
,
len
(
net
[
registry
]),
recreate
=
recreate
)
reg_node
=
re6st_wrap
.
Re6stNode
(
registry
,
reg
,
name
=
reg
.
name
)
registries
.
append
(
reg
)
reg_node
.
run
(
"--gateway"
,
"--disable-proto"
,
"none"
,
"--ip"
,
registry
.
ip
)
nodes
.
append
(
reg_node
)
for
m
in
net
[
registry
]:
node
=
re6st_wrap
.
Re6stNode
(
m
,
reg
)
node
.
run
(
"-i"
+
m
.
iface
.
name
)
nodes
.
append
(
node
)
return
nodes
,
registries
def
wait_stable
(
nodes
,
timeout
=
240
):
def
wait_stable
(
nodes
,
timeout
=
240
):
"""try use ping6 from each node to the other until ping success to all the
"""try use ping6 from each node to the other until ping success to all the
other nodes
other nodes
...
@@ -47,12 +28,13 @@ def wait_stable(nodes, timeout=240):
...
@@ -47,12 +28,13 @@ def wait_stable(nodes, timeout=240):
for
node
in
nodes
:
for
node
in
nodes
:
sub_ips
=
set
(
ips
)
-
{
node
.
ip6
}
sub_ips
=
set
(
ips
)
-
{
node
.
ip6
}
node
.
ping_proc
=
node
.
node
.
Popen
(
node
.
ping_proc
=
node
.
node
.
Popen
(
[
"python"
,
PING_PATH
,
'--retry'
,
'-a'
]
+
list
(
sub_ips
))
[
sys
.
executable
,
PING_PATH
,
'--retry'
,
'-a'
]
+
list
(
sub_ips
),
env
=
os
.
environ
)
# check all the node network can ping each other, in order reverse
# check all the node network can ping each other, in order reverse
unfinished
=
list
(
nodes
)
unfinished
=
list
(
nodes
)
while
unfinished
:
while
unfinished
:
for
i
in
x
range
(
len
(
unfinished
)
-
1
,
-
1
,
-
1
):
for
i
in
range
(
len
(
unfinished
)
-
1
,
-
1
,
-
1
):
node
=
unfinished
[
i
]
node
=
unfinished
[
i
]
if
node
.
ping_proc
.
poll
()
is
not
None
:
if
node
.
ping_proc
.
poll
()
is
not
None
:
logging
.
debug
(
"%s 's network is stable"
,
node
.
name
)
logging
.
debug
(
"%s 's network is stable"
,
node
.
name
)
...
@@ -75,8 +57,42 @@ class TestNet(unittest.TestCase):
...
@@ -75,8 +57,42 @@ class TestNet(unittest.TestCase):
def
setUpClass
(
cls
):
def
setUpClass
(
cls
):
"""create work dir"""
"""create work dir"""
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
def
setUp
(
self
):
re6st_wrap
.
initial
()
re6st_wrap
.
initial
()
def
deploy_re6st
(
self
,
nm
,
recreate
=
False
):
net
=
nm
.
registries
nodes
=
[]
registries
=
[]
re6st_wrap
.
Re6stRegistry
.
registry_seq
=
0
re6st_wrap
.
Re6stNode
.
node_seq
=
0
for
registry
in
net
:
reg
=
re6st_wrap
.
Re6stRegistry
(
registry
,
"2001:db8:42::"
,
len
(
net
[
registry
]),
recreate
=
recreate
)
reg_node
=
re6st_wrap
.
Re6stNode
(
registry
,
reg
,
name
=
reg
.
name
)
registries
.
append
(
reg
)
reg_node
.
run
(
"--gateway"
,
"--disable-proto"
,
"none"
,
"--ip"
,
registry
.
ip
)
nodes
.
append
(
reg_node
)
for
m
in
net
[
registry
]:
node
=
re6st_wrap
.
Re6stNode
(
m
,
reg
)
node
.
run
(
"-i"
+
m
.
iface
.
name
)
nodes
.
append
(
node
)
def
clean_re6st
():
for
node
in
nodes
:
node
.
node
.
destroy
()
node
.
stop
()
for
reg
in
registries
:
reg
.
terminate
()
self
.
addCleanup
(
clean_re6st
)
return
nodes
,
registries
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
"""watch any process leaked after tests"""
"""watch any process leaked after tests"""
...
@@ -94,7 +110,7 @@ class TestNet(unittest.TestCase):
...
@@ -94,7 +110,7 @@ class TestNet(unittest.TestCase):
"""create a network in a net segment, test the connectivity by ping
"""create a network in a net segment, test the connectivity by ping
"""
"""
nm
=
network_build
.
net_route
()
nm
=
network_build
.
net_route
()
nodes
,
_
=
deploy_re6st
(
nm
)
nodes
,
_
=
self
.
deploy_re6st
(
nm
)
wait_stable
(
nodes
,
40
)
wait_stable
(
nodes
,
40
)
time
.
sleep
(
10
)
time
.
sleep
(
10
)
...
@@ -107,7 +123,7 @@ class TestNet(unittest.TestCase):
...
@@ -107,7 +123,7 @@ class TestNet(unittest.TestCase):
then test if network recover, this test seems always failed
then test if network recover, this test seems always failed
"""
"""
nm
=
network_build
.
net_demo
()
nm
=
network_build
.
net_demo
()
nodes
,
_
=
deploy_re6st
(
nm
)
nodes
,
_
=
self
.
deploy_re6st
(
nm
)
wait_stable
(
nodes
,
100
)
wait_stable
(
nodes
,
100
)
...
@@ -126,7 +142,7 @@ class TestNet(unittest.TestCase):
...
@@ -126,7 +142,7 @@ class TestNet(unittest.TestCase):
then test if network recover,
then test if network recover,
"""
"""
nm
=
network_build
.
net_route
()
nm
=
network_build
.
net_route
()
nodes
,
_
=
deploy_re6st
(
nm
)
nodes
,
_
=
self
.
deploy_re6st
(
nm
)
wait_stable
(
nodes
,
40
)
wait_stable
(
nodes
,
40
)
...
...
re6st/tests/test_unit/test_conf.py
View file @
e80ca878
#!/usr/bin/
python2
#!/usr/bin/
env python3
""" unit test for re6st-conf
""" unit test for re6st-conf
"""
"""
...
@@ -6,7 +6,7 @@ import os
...
@@ -6,7 +6,7 @@ import os
import
sys
import
sys
import
unittest
import
unittest
from
shutil
import
rmtree
from
shutil
import
rmtree
from
StringIO
import
StringIO
from
io
import
StringIO
from
mock
import
patch
from
mock
import
patch
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
...
@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase):
...
@@ -36,7 +36,7 @@ class TestConf(unittest.TestCase):
# mocked server cert and pkey
# mocked server cert and pkey
cls
.
pkey
,
cls
.
cert
=
create_ca_file
(
os
.
devnull
,
os
.
devnull
)
cls
.
pkey
,
cls
.
cert
=
create_ca_file
(
os
.
devnull
,
os
.
devnull
)
cls
.
fingerprint
=
""
.
join
(
cls
.
cert
.
digest
(
"sha1"
).
split
(
":"
))
cls
.
fingerprint
=
""
.
join
(
cls
.
cert
.
digest
(
"sha1"
).
decode
(
).
split
(
":"
))
# client.getCa should return a string form cert
# client.getCa should return a string form cert
cls
.
cert
=
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
cls
.
cert
)
cls
.
cert
=
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
cls
.
cert
)
...
@@ -72,7 +72,7 @@ class TestConf(unittest.TestCase):
...
@@ -72,7 +72,7 @@ class TestConf(unittest.TestCase):
# go back to original dir
# go back to original dir
os
.
chdir
(
self
.
origin_dir
)
os
.
chdir
(
self
.
origin_dir
)
@
patch
(
"
__builtin__.raw_
input"
)
@
patch
(
"
builtins.
input"
)
def
test_basic
(
self
,
mock_raw_input
):
def
test_basic
(
self
,
mock_raw_input
):
""" go through all the step
""" go through all the step
getCa, requestToken, requestCertificate
getCa, requestToken, requestCertificate
...
...
re6st/tests/test_unit/test_registry.py
View file @
e80ca878
...
@@ -3,7 +3,7 @@ import os
...
@@ -3,7 +3,7 @@ import os
import
random
import
random
import
string
import
string
import
json
import
json
import
http
lib
import
http
.client
import
base64
import
base64
import
unittest
import
unittest
import
hmac
import
hmac
...
@@ -11,18 +11,21 @@ import hashlib
...
@@ -11,18 +11,21 @@ import hashlib
import
time
import
time
import
tempfile
import
tempfile
from
argparse
import
Namespace
from
argparse
import
Namespace
from
sqlite3
import
Cursor
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
from
mock
import
Mock
,
patch
from
mock
import
Mock
,
patch
from
pathlib
2
import
Path
from
pathlib
import
Path
from
re6st
import
registry
from
re6st
import
registry
,
x509
from
re6st.tests.tools
import
*
from
re6st.tests.tools
import
*
from
re6st.tests
import
DEMO_PATH
from
re6st.tests
import
DEMO_PATH
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# TODO test for request_dump, requestToken, getNetworkConfig, getBoostrapPeer
# getIPV4Information, versions
# getIPV4Information, versions
def
load_config
(
filename
=
"registry.json"
)
:
def
load_config
(
filename
:
str
=
"registry.json"
)
->
Namespace
:
with
open
(
filename
)
as
f
:
with
open
(
filename
)
as
f
:
config
=
json
.
load
(
f
)
config
=
json
.
load
(
f
)
config
[
"dh"
]
=
DEMO_PATH
/
"dh2048.pem"
config
[
"dh"
]
=
DEMO_PATH
/
"dh2048.pem"
...
@@ -36,23 +39,25 @@ def load_config(filename="registry.json"):
...
@@ -36,23 +39,25 @@ def load_config(filename="registry.json"):
return
Namespace
(
**
config
)
return
Namespace
(
**
config
)
def
get_cert
(
cur
,
prefix
):
def
get_cert
(
cur
:
Cursor
,
prefix
:
str
):
res
=
cur
.
execute
(
res
=
cur
.
execute
(
"SELECT cert FROM cert WHERE prefix=?"
,
(
prefix
,)).
fetchone
()
"SELECT cert FROM cert WHERE prefix=?"
,
(
prefix
,)).
fetchone
()
return
res
[
0
]
return
res
[
0
]
def
insert_cert
(
cur
,
ca
,
prefix
,
not_after
=
None
,
email
=
None
):
def
insert_cert
(
cur
:
Cursor
,
ca
:
x509
.
Cert
,
prefix
:
str
,
not_after
=
None
,
email
=
None
):
key
,
csr
=
generate_csr
()
key
,
csr
=
generate_csr
()
cert
=
generate_cert
(
ca
.
ca
,
ca
.
key
,
csr
,
prefix
,
insert_cert
.
serial
,
not_after
)
cert
=
generate_cert
(
ca
.
ca
,
ca
.
key
,
csr
,
prefix
,
insert_cert
.
serial
,
not_after
)
cur
.
execute
(
"INSERT INTO cert VALUES (?,?,?)"
,
(
prefix
,
email
,
cert
))
cur
.
execute
(
"INSERT INTO cert VALUES (?,?,?)"
,
(
prefix
,
email
,
cert
))
insert_cert
.
serial
+=
1
insert_cert
.
serial
+=
1
return
key
,
cert
return
key
,
cert
insert_cert
.
serial
=
0
insert_cert
.
serial
=
0
def
delete_cert
(
cur
,
prefix
):
def
delete_cert
(
cur
:
Cursor
,
prefix
:
str
):
cur
.
execute
(
"DELETE FROM cert WHERE prefix = ?"
,
(
prefix
,))
cur
.
execute
(
"DELETE FROM cert WHERE prefix = ?"
,
(
prefix
,))
...
@@ -68,6 +73,7 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -68,6 +73,7 @@ class TestRegistryServer(unittest.TestCase):
@
classmethod
@
classmethod
def
tearDownClass
(
cls
):
def
tearDownClass
(
cls
):
cls
.
server
.
close
()
# remove database
# remove database
for
file
in
[
cls
.
config
.
db
,
cls
.
config
.
ca
,
cls
.
config
.
key
]:
for
file
in
[
cls
.
config
.
db
,
cls
.
config
.
ca
,
cls
.
config
.
key
]:
try
:
try
:
...
@@ -80,14 +86,23 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -80,14 +86,23 @@ class TestRegistryServer(unittest.TestCase):
+
"@mail.com"
+
"@mail.com"
def
test_recv
(
self
):
def
test_recv
(
self
):
recv
=
self
.
server
.
sock
.
recv
=
Mock
()
side_effect
=
iter
([
recv
.
side_effect
=
[
"0001001001001a_msg"
,
"0001001001001a_msg"
,
"0001001001002
\
000
1dqdq"
,
"0001001001002
\
000
1dqdq"
,
"0001001001001
\
000
a_msg"
,
"0001001001001
\
000
a_msg"
,
"0001001001001
\
000
\
4
a_msg"
,
"0001001001001
\
000
\
4
a_msg"
,
"0000000000000
\
0
"
# ERROR, IndexError: msg is null
"0000000000000
\
0
"
# ERROR, IndexError: msg is null
]
])
class
SocketProxy
:
def
__init__
(
self
,
wrappee
):
self
.
wrappee
=
wrappee
self
.
recv
=
lambda
_
:
next
(
side_effect
)
def
__getattr__
(
self
,
attr
):
return
getattr
(
self
.
wrappee
,
attr
)
self
.
server
.
sock
=
SocketProxy
(
self
.
server
.
sock
)
try
:
try
:
res1
=
self
.
server
.
recv
(
4
)
res1
=
self
.
server
.
recv
(
4
)
...
@@ -115,7 +130,7 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -115,7 +130,7 @@ class TestRegistryServer(unittest.TestCase):
now
=
int
(
time
.
time
())
-
self
.
config
.
grace_period
+
20
now
=
int
(
time
.
time
())
-
self
.
config
.
grace_period
+
20
# makeup data
# makeup data
insert_cert
(
cur
,
self
.
server
.
cert
,
prefix_old
,
1
)
insert_cert
(
cur
,
self
.
server
.
cert
,
prefix_old
,
1
)
insert_cert
(
cur
,
self
.
server
.
cert
,
prefix
,
now
-
1
)
insert_cert
(
cur
,
self
.
server
.
cert
,
prefix
,
now
-
1
)
cur
.
execute
(
"INSERT INTO token VALUES (?,?,?,?)"
,
cur
.
execute
(
"INSERT INTO token VALUES (?,?,?,?)"
,
(
token_old
,
self
.
email
,
4
,
2
))
(
token_old
,
self
.
email
,
4
,
2
))
cur
.
execute
(
"INSERT INTO token VALUES (?,?,?,?)"
,
cur
.
execute
(
"INSERT INTO token VALUES (?,?,?,?)"
,
...
@@ -143,16 +158,16 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -143,16 +158,16 @@ class TestRegistryServer(unittest.TestCase):
prefix
=
"0000000011111111"
prefix
=
"0000000011111111"
method
=
"func"
method
=
"func"
protocol
=
7
protocol
=
7
params
=
{
"cn"
:
prefix
,
"a"
:
1
,
"b"
:
2
}
params
=
{
"cn"
:
prefix
,
"a"
:
1
,
"b"
:
2
}
func
.
getcallargs
.
return_value
=
params
func
.
getcallargs
.
return_value
=
params
del
func
.
_private
del
func
.
_private
func
.
return_value
=
result
=
"this_is_a_result"
func
.
return_value
=
result
=
b
"this_is_a_result"
key
=
"this_is_a_key"
key
=
b
"this_is_a_key"
self
.
server
.
sessions
[
prefix
]
=
[(
key
,
protocol
)]
self
.
server
.
sessions
[
prefix
]
=
[(
key
,
protocol
)]
request
=
Mock
()
request
=
Mock
()
request
.
path
=
"/func?a=1&b=2&cn=0000000011111111"
request
.
path
=
"/func?a=1&b=2&cn=0000000011111111"
request
.
headers
=
{
registry
.
HMAC_HEADER
:
base64
.
b64encode
(
request
.
headers
=
{
registry
.
HMAC_HEADER
:
base64
.
b64encode
(
hmac
.
HMAC
(
key
,
request
.
path
,
hashlib
.
sha1
).
digest
())}
hmac
.
HMAC
(
key
,
request
.
path
.
encode
()
,
hashlib
.
sha1
).
digest
())}
self
.
server
.
handle_request
(
request
,
method
,
params
)
self
.
server
.
handle_request
(
request
,
method
,
params
)
...
@@ -162,11 +177,12 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -162,11 +177,12 @@ class TestRegistryServer(unittest.TestCase):
[(
hashlib
.
sha1
(
key
).
digest
(),
protocol
)])
[(
hashlib
.
sha1
(
key
).
digest
(),
protocol
)])
func
.
assert_called_once_with
(
**
params
)
func
.
assert_called_once_with
(
**
params
)
# http response check
# http response check
request
.
send_response
.
assert_called_once_with
(
http
lib
.
OK
)
request
.
send_response
.
assert_called_once_with
(
http
.
client
.
OK
)
request
.
send_header
.
assert_any_call
(
"Content-Length"
,
str
(
len
(
result
)))
request
.
send_header
.
assert_any_call
(
"Content-Length"
,
str
(
len
(
result
)))
request
.
send_header
.
assert_any_call
(
request
.
send_header
.
assert_any_call
(
registry
.
HMAC_HEADER
,
registry
.
HMAC_HEADER
,
base64
.
b64encode
(
hmac
.
HMAC
(
key
,
result
,
hashlib
.
sha1
).
digest
()))
base64
.
b64encode
(
hmac
.
HMAC
(
key
,
result
,
hashlib
.
sha1
).
digest
())
.
decode
(
"ascii"
))
request
.
wfile
.
write
.
assert_called_once_with
(
result
)
request
.
wfile
.
write
.
assert_called_once_with
(
result
)
# remove the create session \n
# remove the create session \n
...
@@ -176,7 +192,7 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -176,7 +192,7 @@ class TestRegistryServer(unittest.TestCase):
def
test_handle_request_private
(
self
,
func
):
def
test_handle_request_private
(
self
,
func
):
"""case request with _private attr"""
"""case request with _private attr"""
method
=
"func"
method
=
"func"
params
=
{
"a"
:
1
,
"b"
:
2
}
params
=
{
"a"
:
1
,
"b"
:
2
}
func
.
getcallargs
.
return_value
=
params
func
.
getcallargs
.
return_value
=
params
func
.
return_value
=
None
func
.
return_value
=
None
request_good
=
Mock
()
request_good
=
Mock
()
...
@@ -189,8 +205,8 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -189,8 +205,8 @@ class TestRegistryServer(unittest.TestCase):
self
.
server
.
handle_request
(
request_bad
,
method
,
params
)
self
.
server
.
handle_request
(
request_bad
,
method
,
params
)
func
.
assert_called_once_with
(
**
params
)
func
.
assert_called_once_with
(
**
params
)
request_bad
.
send_error
.
assert_called_once_with
(
http
lib
.
FORBIDDEN
)
request_bad
.
send_error
.
assert_called_once_with
(
http
.
client
.
FORBIDDEN
)
request_good
.
send_response
.
assert_called_once_with
(
http
lib
.
NO_CONTENT
)
request_good
.
send_response
.
assert_called_once_with
(
http
.
client
.
NO_CONTENT
)
# will cause valueError, if a node send hello twice to a registry
# will cause valueError, if a node send hello twice to a registry
def
test_getPeerProtocol
(
self
):
def
test_getPeerProtocol
(
self
):
...
@@ -213,7 +229,7 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -213,7 +229,7 @@ class TestRegistryServer(unittest.TestCase):
res
=
self
.
server
.
hello
(
prefix
,
protocol
=
protocol
)
res
=
self
.
server
.
hello
(
prefix
,
protocol
=
protocol
)
# decrypt
# decrypt
length
=
len
(
res
)
/
2
length
=
len
(
res
)
//
2
key
,
sign
=
res
[:
length
],
res
[
length
:]
key
,
sign
=
res
[:
length
],
res
[
length
:]
key
=
decrypt
(
pkey
,
key
)
key
=
decrypt
(
pkey
,
key
)
self
.
assertEqual
(
self
.
server
.
sessions
[
prefix
][
-
1
][
0
],
key
,
self
.
assertEqual
(
self
.
server
.
sessions
[
prefix
][
-
1
][
0
],
key
,
...
@@ -282,7 +298,7 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -282,7 +298,7 @@ class TestRegistryServer(unittest.TestCase):
nb_less
=
0
nb_less
=
0
for
cert
in
self
.
server
.
iterCert
():
for
cert
in
self
.
server
.
iterCert
():
s
=
cert
[
0
].
get_subject
().
serialNumber
s
=
cert
[
0
].
get_subject
().
serialNumber
if
(
s
and
int
(
s
)
<=
serial
)
:
if
s
and
int
(
s
)
<=
serial
:
nb_less
+=
1
nb_less
+=
1
self
.
assertEqual
(
nb_less
,
serial
)
self
.
assertEqual
(
nb_less
,
serial
)
...
@@ -378,7 +394,7 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -378,7 +394,7 @@ class TestRegistryServer(unittest.TestCase):
hmacs
=
get_hmac
()
hmacs
=
get_hmac
()
key_1
=
hmacs
[
1
]
key_1
=
hmacs
[
1
]
self
.
assertEqual
(
hmacs
,
[
None
,
key_1
,
''
])
self
.
assertEqual
(
hmacs
,
[
None
,
key_1
,
b
''
])
# step 2
# step 2
self
.
server
.
updateHMAC
()
self
.
server
.
updateHMAC
()
...
@@ -397,12 +413,11 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -397,12 +413,11 @@ class TestRegistryServer(unittest.TestCase):
self
.
assertEqual
(
get_hmac
(),
[
None
,
key_2
,
key_1
])
self
.
assertEqual
(
get_hmac
(),
[
None
,
key_2
,
key_1
])
#
set
p 5
#
ste
p 5
self
.
server
.
updateHMAC
()
self
.
server
.
updateHMAC
()
self
.
assertEqual
(
get_hmac
(),
[
key_2
,
None
,
None
])
self
.
assertEqual
(
get_hmac
(),
[
key_2
,
None
,
None
])
def
test_getNodePrefix
(
self
):
def
test_getNodePrefix
(
self
):
# prefix in short format
# prefix in short format
prefix
=
"0000000101"
prefix
=
"0000000101"
...
@@ -426,20 +441,38 @@ class TestRegistryServer(unittest.TestCase):
...
@@ -426,20 +441,38 @@ class TestRegistryServer(unittest.TestCase):
(
'0000000000000001'
,
'2 0/16 6/16'
)
(
'0000000000000001'
,
'2 0/16 6/16'
)
]
]
recv
.
side_effect
=
recv_case
recv
.
side_effect
=
recv_case
def
side_effct
(
rlist
,
wlist
,
elist
,
timeout
):
def
side_effct
(
rlist
,
wlist
,
elist
,
timeout
):
# rlist is true until the len(recv_case)th call
# rlist is true until the len(recv_case)th call
side_effct
.
i
-=
side_effct
.
i
>
0
side_effct
.
i
-=
side_effct
.
i
>
0
return
[
side_effct
.
i
,
wlist
,
None
]
return
[
side_effct
.
i
,
wlist
,
None
]
side_effct
.
i
=
len
(
recv_case
)
+
1
side_effct
.
i
=
len
(
recv_case
)
+
1
select
.
side_effect
=
side_effct
select
.
side_effect
=
side_effct
res
=
self
.
server
.
topology
()
res
=
self
.
server
.
topology
()
expect_res
=
'{"36893488147419103232/80": ["0/16", "7/16"], '
\
class
CustomDecoder
(
json
.
JSONDecoder
):
'"": ["36893488147419103232/80", "3/16", "1/16", "0/16", "7/16"], '
\
def
__init__
(
self
,
**
kwargs
):
'"4/16": ["0/16"], "3/16": ["0/16", "7/16"], "0/16": ["6/16", "7/16"], '
\
super
().
__init__
(
**
kwargs
)
'"1/16": ["6/16", "0/16"], "7/16": ["6/16", "4/16"]}'''
self
.
parse_array
=
self
.
JSONArray
self.assertEqual(res, expect_res)
self
.
scan_once
=
json
.
scanner
.
py_make_scanner
(
self
)
def
JSONArray
(
self
,
*
args
,
**
kw
):
values
,
end
=
json
.
decoder
.
JSONArray
(
*
args
,
**
kw
)
return
set
(
values
),
end
res
=
json
.
loads
(
res
,
cls
=
CustomDecoder
)
self
.
assertEqual
(
res
,
{
""
:
{
"36893488147419103232/80"
,
"3/16"
,
"1/16"
,
"0/16"
,
"7/16"
},
"0/16"
:
{
"6/16"
,
"7/16"
},
"1/16"
:
{
"6/16"
,
"0/16"
},
"36893488147419103232/80"
:
{
"0/16"
,
"7/16"
},
"3/16"
:
{
"0/16"
,
"7/16"
},
"4/16"
:
{
"0/16"
},
"7/16"
:
{
"6/16"
,
"4/16"
},
})
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
re6st/tests/test_unit/test_registry_client.py
View file @
e80ca878
...
@@ -2,7 +2,7 @@ import sys
...
@@ -2,7 +2,7 @@ import sys
import
os
import
os
import
unittest
import
unittest
import
hmac
import
hmac
import
http
lib
import
http
.client
import
base64
import
base64
import
hashlib
import
hashlib
from
mock
import
Mock
,
patch
from
mock
import
Mock
,
patch
...
@@ -26,15 +26,15 @@ class TestRegistryClient(unittest.TestCase):
...
@@ -26,15 +26,15 @@ class TestRegistryClient(unittest.TestCase):
self
.
assertEqual
(
client1
.
_path
,
"/example"
)
self
.
assertEqual
(
client1
.
_path
,
"/example"
)
self
.
assertEqual
(
client1
.
_conn
.
host
,
"localhost"
)
self
.
assertEqual
(
client1
.
_conn
.
host
,
"localhost"
)
self
.
assertIsInstance
(
client1
.
_conn
,
http
lib
.
HTTPSConnection
)
self
.
assertIsInstance
(
client1
.
_conn
,
http
.
client
.
HTTPSConnection
)
self
.
assertIsInstance
(
client2
.
_conn
,
http
lib
.
HTTPConnection
)
self
.
assertIsInstance
(
client2
.
_conn
,
http
.
client
.
HTTPConnection
)
def
test_rpc_hello
(
self
):
def
test_rpc_hello
(
self
):
prefix
=
"0000000011111111"
prefix
=
"0000000011111111"
protocol
=
"7"
protocol
=
"7"
body
=
"a_hmac_key"
body
=
"a_hmac_key"
query
=
"/hello?client_prefix=0000000011111111&protocol=7"
query
=
"/hello?client_prefix=0000000011111111&protocol=7"
response
=
fakeResponse
(
body
,
http
lib
.
OK
)
response
=
fakeResponse
(
body
,
http
.
client
.
OK
)
self
.
client
.
_conn
.
getresponse
.
return_value
=
response
self
.
client
.
_conn
.
getresponse
.
return_value
=
response
res
=
self
.
client
.
hello
(
prefix
,
protocol
)
res
=
self
.
client
.
hello
(
prefix
,
protocol
)
...
@@ -52,14 +52,15 @@ class TestRegistryClient(unittest.TestCase):
...
@@ -52,14 +52,15 @@ class TestRegistryClient(unittest.TestCase):
self
.
client
.
_hmac
=
None
self
.
client
.
_hmac
=
None
self
.
client
.
hello
=
Mock
(
return_value
=
"aaabbb"
)
self
.
client
.
hello
=
Mock
(
return_value
=
"aaabbb"
)
self
.
client
.
cert
=
Mock
()
self
.
client
.
cert
=
Mock
()
key
=
"this_is_a_key"
key
=
b
"this_is_a_key"
self
.
client
.
cert
.
decrypt
.
return_value
=
key
self
.
client
.
cert
.
decrypt
.
return_value
=
key
h
=
hmac
.
HMAC
(
key
,
query
,
hashlib
.
sha1
).
digest
()
h
=
hmac
.
HMAC
(
key
,
query
.
encode
()
,
hashlib
.
sha1
).
digest
()
key
=
hashlib
.
sha1
(
key
).
digest
()
key
=
hashlib
.
sha1
(
key
).
digest
()
# response part
# response part
body
=
None
body
=
b'this is a body'
response
=
fakeResponse
(
body
,
httplib
.
NO_CONTENT
)
response
=
fakeResponse
(
body
,
http
.
client
.
NO_CONTENT
)
response
.
msg
=
dict
(
Re6stHMAC
=
hmac
.
HMAC
(
key
,
body
,
hashlib
.
sha1
).
digest
())
response
.
msg
=
dict
(
Re6stHMAC
=
base64
.
b64encode
(
hmac
.
HMAC
(
key
,
body
,
hashlib
.
sha1
).
digest
()))
self
.
client
.
_conn
.
getresponse
.
return_value
=
response
self
.
client
.
_conn
.
getresponse
.
return_value
=
response
res
=
self
.
client
.
getNetworkConfig
(
cn
)
res
=
self
.
client
.
getNetworkConfig
(
cn
)
...
...
re6st/tests/test_unit/test_tunnel/test_base_tunnel_manager.py
View file @
e80ca878
#!/usr/bin/
python2
#!/usr/bin/
env python3
import
os
import
os
import
sys
import
sys
import
unittest
import
unittest
...
@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase):
...
@@ -67,7 +67,7 @@ class testBaseTunnelManager(unittest.TestCase):
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# @patch("re6st.tunnel.BaseTunnelManager._makeTunnel", create=True)
# def test_processPacket_address_with_msg_peer(self, makeTunnel):
# def test_processPacket_address_with_msg_peer(self, makeTunnel):
# """code is 1, peer and msg not none """
# """code is 1, peer and msg not none """
# c =
chr(1)
# c =
b"\x01"
# msg = "address"
# msg = "address"
# peer = x509.Peer("000001")
# peer = x509.Peer("000001")
# self.tunnel._connecting = {peer}
# self.tunnel._connecting = {peer}
...
@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase):
...
@@ -81,7 +81,7 @@ class testBaseTunnelManager(unittest.TestCase):
def
test_processPacket_address
(
self
):
def
test_processPacket_address
(
self
):
"""code is 1, for address. And peer or msg are none"""
"""code is 1, for address. And peer or msg are none"""
c
=
chr
(
1
)
c
=
b"
\
x01
"
self
.
tunnel
.
_address
=
{
1
:
"1,1"
,
2
:
"2,2"
}
self
.
tunnel
.
_address
=
{
1
:
"1,1"
,
2
:
"2,2"
}
res
=
self
.
tunnel
.
_processPacket
(
c
)
res
=
self
.
tunnel
.
_processPacket
(
c
)
...
@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase):
...
@@ -95,7 +95,7 @@ class testBaseTunnelManager(unittest.TestCase):
and each address join by ;
and each address join by ;
it will truncate address which has more than 3 element
it will truncate address which has more than 3 element
"""
"""
c
=
chr
(
1
)
c
=
b"
\
x01
"
peer
=
x509
.
Peer
(
"000001"
)
peer
=
x509
.
Peer
(
"000001"
)
peer
.
protocol
=
1
peer
.
protocol
=
1
self
.
tunnel
.
_peers
.
append
(
peer
)
self
.
tunnel
.
_peers
.
append
(
peer
)
...
@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase):
...
@@ -111,11 +111,11 @@ class testBaseTunnelManager(unittest.TestCase):
"""code is 0, for network version, peer is not none
"""code is 0, for network version, peer is not none
2 case, one modify the version, one not
2 case, one modify the version, one not
"""
"""
c
=
chr
(
0
)
c
=
b"
\
x00
"
peer
=
x509
.
Peer
(
"000001"
)
peer
=
x509
.
Peer
(
"000001"
)
version1
=
"00003"
version1
=
b
"00003"
version2
=
"00007"
version2
=
b
"00007"
self
.
tunnel
.
_version
=
version3
=
"00005"
self
.
tunnel
.
_version
=
version3
=
b
"00005"
self
.
tunnel
.
_peers
.
append
(
peer
)
self
.
tunnel
.
_peers
.
append
(
peer
)
res
=
self
.
tunnel
.
_processPacket
(
c
+
version1
,
peer
)
res
=
self
.
tunnel
.
_processPacket
(
c
+
version1
,
peer
)
...
...
re6st/tests/test_unit/test_tunnel/test_multi_gateway_manager.py
View file @
e80ca878
#!/usr/bin/
python2
#!/usr/bin/
env python3
import
os
import
os
import
sys
import
sys
import
unittest
import
unittest
...
...
re6st/tests/tools.py
View file @
e80ca878
...
@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
...
@@ -30,9 +30,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
return
return
crypto.X509Cert in pem format
crypto.X509Cert in pem format
"""
"""
if
type
(
ca
)
is
str
:
if
type
(
ca
)
is
bytes
:
ca
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
ca
)
ca
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
ca
)
if
type
(
ca_key
)
is
str
:
if
type
(
ca_key
)
is
bytes
:
ca_key
=
crypto
.
load_privatekey
(
crypto
.
FILETYPE_PEM
,
ca_key
)
ca_key
=
crypto
.
load_privatekey
(
crypto
.
FILETYPE_PEM
,
ca_key
)
req
=
crypto
.
load_certificate_request
(
crypto
.
FILETYPE_PEM
,
csr
)
req
=
crypto
.
load_certificate_request
(
crypto
.
FILETYPE_PEM
,
csr
)
...
@@ -40,7 +40,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
...
@@ -40,7 +40,7 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
cert
.
gmtime_adj_notBefore
(
0
)
cert
.
gmtime_adj_notBefore
(
0
)
if
not_after
:
if
not_after
:
cert
.
set_notAfter
(
cert
.
set_notAfter
(
time
.
strftime
(
"%Y%m%d%H%M%SZ"
,
time
.
gmtime
(
not_after
)))
time
.
strftime
(
"%Y%m%d%H%M%SZ"
,
time
.
gmtime
(
not_after
))
.
encode
()
)
else
:
else
:
cert
.
gmtime_adj_notAfter
(
registry
.
RegistryServer
.
cert_duration
)
cert
.
gmtime_adj_notAfter
(
registry
.
RegistryServer
.
cert_duration
)
subject
=
req
.
get_subject
()
subject
=
req
.
get_subject
()
...
@@ -56,9 +56,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
...
@@ -56,9 +56,9 @@ def generate_cert(ca, ca_key, csr, prefix, serial, not_after=None):
def
create_cert_file
(
pkey_file
,
cert_file
,
ca
,
ca_key
,
prefix
,
serial
):
def
create_cert_file
(
pkey_file
,
cert_file
,
ca
,
ca_key
,
prefix
,
serial
):
pkey
,
csr
=
generate_csr
()
pkey
,
csr
=
generate_csr
()
cert
=
generate_cert
(
ca
,
ca_key
,
csr
,
prefix
,
serial
)
cert
=
generate_cert
(
ca
,
ca_key
,
csr
,
prefix
,
serial
)
with
open
(
pkey_file
,
'w'
)
as
f
:
with
open
(
pkey_file
,
'w
b
'
)
as
f
:
f
.
write
(
pkey
)
f
.
write
(
pkey
)
with
open
(
cert_file
,
'w'
)
as
f
:
with
open
(
cert_file
,
'w
b
'
)
as
f
:
f
.
write
(
cert
)
f
.
write
(
cert
)
return
pkey
,
cert
return
pkey
,
cert
...
@@ -84,26 +84,23 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
...
@@ -84,26 +84,23 @@ def create_ca_file(pkey_file, cert_file, serial=0x120010db80042):
cert
.
set_pubkey
(
key
)
cert
.
set_pubkey
(
key
)
cert
.
sign
(
key
,
"sha512"
)
cert
.
sign
(
key
,
"sha512"
)
with
open
(
pkey_file
,
'w'
)
as
pkey_file
:
with
open
(
pkey_file
,
'w
b
'
)
as
pkey_file
:
pkey_file
.
write
(
crypto
.
dump_privatekey
(
crypto
.
FILETYPE_PEM
,
key
))
pkey_file
.
write
(
crypto
.
dump_privatekey
(
crypto
.
FILETYPE_PEM
,
key
))
with
open
(
cert_file
,
'w'
)
as
cert_file
:
with
open
(
cert_file
,
'w
b
'
)
as
cert_file
:
cert_file
.
write
(
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
cert
))
cert_file
.
write
(
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
cert
))
return
key
,
cert
return
key
,
cert
def
prefix2cn
(
prefix
)
:
def
prefix2cn
(
prefix
:
str
)
->
str
:
return
"%u/%u"
%
(
int
(
prefix
,
2
),
len
(
prefix
))
return
"%u/%u"
%
(
int
(
prefix
,
2
),
len
(
prefix
))
def
serial2prefix
(
serial
)
:
def
serial2prefix
(
serial
:
int
)
->
str
:
return
bin
(
serial
)[
2
:].
rjust
(
16
,
'0'
)
return
bin
(
serial
)[
2
:].
rjust
(
16
,
'0'
)
# pkey: private key
# pkey: private key
def
decrypt
(
pkey
,
incontent
)
:
def
decrypt
(
pkey
:
bytes
,
incontent
:
bytes
)
->
bytes
:
with
open
(
"node.key"
,
'w'
)
as
f
:
with
open
(
"node.key"
,
'w
b
'
)
as
f
:
f
.
write
(
pkey
)
f
.
write
(
pkey
)
args
=
"openssl rsautl -decrypt -inkey node.key"
.
split
()
args
=
"openssl rsautl -decrypt -inkey node.key"
.
split
()
p
=
subprocess
.
Popen
(
return
subprocess
.
run
(
args
,
input
=
incontent
,
stdout
=
subprocess
.
PIPE
).
stdout
args
,
stdin
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
outcontent
,
err
=
p
.
communicate
(
incontent
)
return
outcontent
re6st/tunnel.py
View file @
e80ca878
...
@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket
...
@@ -2,8 +2,13 @@ import errno, json, logging, os, platform, random, socket
import
subprocess
,
struct
,
sys
,
time
,
weakref
import
subprocess
,
struct
,
sys
,
time
,
weakref
from
collections
import
defaultdict
,
deque
from
collections
import
defaultdict
,
deque
from
bisect
import
bisect
,
insort
from
bisect
import
bisect
,
insort
from
collections.abc
import
Iterator
,
Sequence
from
typing
import
Callable
,
TYPE_CHECKING
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
from
.
import
ctl
,
plib
,
utils
,
version
,
x509
from
.
import
ctl
,
plib
,
utils
,
version
,
x509
if
TYPE_CHECKING
:
from
.
import
cache
PORT
=
326
PORT
=
326
...
@@ -21,7 +26,8 @@ proto_dict = {
...
@@ -21,7 +26,8 @@ proto_dict = {
proto_dict
[
'tcp'
]
=
proto_dict
[
'tcp4'
]
proto_dict
[
'tcp'
]
=
proto_dict
[
'tcp4'
]
proto_dict
[
'udp'
]
=
proto_dict
[
'udp4'
]
proto_dict
[
'udp'
]
=
proto_dict
[
'udp4'
]
def
resolve
(
ip
,
port
,
proto
):
def
resolve
(
ip
,
port
,
proto
:
str
)
\
->
tuple
[
socket
.
AddressFamily
|
None
,
Iterator
[
str
]]:
try
:
try
:
family
,
proto
=
proto_dict
[
proto
]
family
,
proto
=
proto_dict
[
proto
]
except
KeyError
:
except
KeyError
:
...
@@ -31,16 +37,16 @@ def resolve(ip, port, proto):
...
@@ -31,16 +37,16 @@ def resolve(ip, port, proto):
class
MultiGatewayManager
(
dict
):
class
MultiGatewayManager
(
dict
):
def
__init__
(
self
,
gateway
):
def
__init__
(
self
,
gateway
:
Callable
[[
str
],
str
]
):
self
.
_gw
=
gateway
self
.
_gw
=
gateway
def
_route
(
self
,
cmd
,
dest
,
gw
):
def
_route
(
self
,
cmd
:
str
,
dest
:
str
,
gw
:
str
):
if
gw
:
if
gw
:
cmd
=
'ip'
,
'-4'
,
'route'
,
cmd
,
'%s/32'
%
dest
,
'via'
,
gw
cmd
=
'ip'
,
'-4'
,
'route'
,
cmd
,
'%s/32'
%
dest
,
'via'
,
gw
logging
.
trace
(
'%r'
,
cmd
)
logging
.
trace
(
'%r'
,
cmd
)
subprocess
.
check_call
(
cmd
)
subprocess
.
check_call
(
cmd
)
def
add
(
self
,
dest
,
route
):
def
add
(
self
,
dest
:
str
,
route
:
bool
):
try
:
try
:
self
[
dest
][
1
]
+=
1
self
[
dest
][
1
]
+=
1
except
KeyError
:
except
KeyError
:
...
@@ -48,7 +54,7 @@ class MultiGatewayManager(dict):
...
@@ -48,7 +54,7 @@ class MultiGatewayManager(dict):
self
[
dest
]
=
[
gw
,
0
]
self
[
dest
]
=
[
gw
,
0
]
self
.
_route
(
'add'
,
dest
,
gw
)
self
.
_route
(
'add'
,
dest
,
gw
)
def
remove
(
self
,
dest
):
def
remove
(
self
,
dest
:
str
):
gw
,
count
=
self
[
dest
]
gw
,
count
=
self
[
dest
]
if
count
:
if
count
:
self
[
dest
][
1
]
=
count
-
1
self
[
dest
][
1
]
=
count
-
1
...
@@ -59,13 +65,14 @@ class MultiGatewayManager(dict):
...
@@ -59,13 +65,14 @@ class MultiGatewayManager(dict):
except
:
except
:
pass
pass
class
Connection
(
object
)
:
class
Connection
:
_retry
=
0
_retry
=
0
serial
=
None
serial
=
None
time
=
float
(
'inf'
)
time
=
float
(
'inf'
)
def
__init__
(
self
,
tunnel_manager
,
address_list
,
iface
,
prefix
):
def
__init__
(
self
,
tunnel_manager
:
"TunnelManager"
,
address_list
,
iface
,
prefix
):
self
.
tunnel_manager
=
tunnel_manager
self
.
tunnel_manager
=
tunnel_manager
self
.
address_list
=
address_list
self
.
address_list
=
address_list
self
.
iface
=
iface
self
.
iface
=
iface
...
@@ -94,7 +101,7 @@ class Connection(object):
...
@@ -94,7 +101,7 @@ class Connection(object):
'--remap-usr1'
,
'SIGTERM'
,
'--remap-usr1'
,
'SIGTERM'
,
'--ping-exit'
,
str
(
tm
.
timeout
),
'--ping-exit'
,
str
(
tm
.
timeout
),
'--route-up'
,
'%s %u'
%
(
plib
.
ovpn_client
,
tm
.
write_sock
.
fileno
()),
'--route-up'
,
'%s %u'
%
(
plib
.
ovpn_client
,
tm
.
write_sock
.
fileno
()),
*
tm
.
ovpn_args
)
*
tm
.
ovpn_args
,
pass_fds
=
[
tm
.
write_sock
.
fileno
()]
)
tm
.
resetTunnelRefresh
()
tm
.
resetTunnelRefresh
()
self
.
_retry
+=
1
self
.
_retry
+=
1
...
@@ -109,7 +116,7 @@ class Connection(object):
...
@@ -109,7 +116,7 @@ class Connection(object):
if
i
:
if
i
:
cache
.
addPeer
(
self
.
_prefix
,
','
.
join
(
self
.
address_list
[
i
]),
True
)
cache
.
addPeer
(
self
.
_prefix
,
','
.
join
(
self
.
address_list
[
i
]),
True
)
else
:
else
:
cache
.
connecting
(
self
.
_prefix
,
0
)
cache
.
connecting
(
self
.
_prefix
,
False
)
def
close
(
self
):
def
close
(
self
):
try
:
try
:
...
@@ -132,7 +139,7 @@ class Connection(object):
...
@@ -132,7 +139,7 @@ class Connection(object):
self
.
open
()
self
.
open
()
return
0
return
0
class
TunnelKiller
(
object
)
:
class
TunnelKiller
:
state
=
None
state
=
None
...
@@ -169,7 +176,7 @@ class TunnelKiller(object):
...
@@ -169,7 +176,7 @@ class TunnelKiller(object):
if
(
self
.
address
,
self
.
ifindex
)
in
tm
.
ctl
.
locked
:
if
(
self
.
address
,
self
.
ifindex
)
in
tm
.
ctl
.
locked
:
self
.
state
=
'locked'
self
.
state
=
'locked'
self
.
timeout
=
time
.
time
()
+
2
*
tm
.
timeout
self
.
timeout
=
time
.
time
()
+
2
*
tm
.
timeout
tm
.
sendto
(
self
.
peer
,
'
\
2
'
if
self
.
client
else
'
\
3
'
)
tm
.
sendto
(
self
.
peer
,
b'
\
2
'
if
self
.
client
else
b
'
\
3
'
)
else
:
else
:
self
.
timeout
=
0
self
.
timeout
=
0
...
@@ -186,7 +193,7 @@ class TunnelKiller(object):
...
@@ -186,7 +193,7 @@ class TunnelKiller(object):
locked
=
unlocking
=
lambda
_
:
None
locked
=
unlocking
=
lambda
_
:
None
class
BaseTunnelManager
(
object
)
:
class
BaseTunnelManager
:
# TODO: To minimize downtime when network parameters change, we should do
# TODO: To minimize downtime when network parameters change, we should do
# our best to not restart any process. Ideally, this list should be
# our best to not restart any process. Ideally, this list should be
...
@@ -198,7 +205,8 @@ class BaseTunnelManager(object):
...
@@ -198,7 +205,8 @@ class BaseTunnelManager(object):
_geoiplookup
=
None
_geoiplookup
=
None
_forward
=
None
_forward
=
None
def
__init__
(
self
,
control_socket
,
cache
,
cert
,
conf_country
,
address
=
()):
def
__init__
(
self
,
control_socket
,
cache
:
"cache.Cache"
,
cert
:
x509
.
Cert
,
conf_country
,
address
=
()):
self
.
cert
=
cert
self
.
cert
=
cert
self
.
_network
=
cert
.
network
self
.
_network
=
cert
.
network
self
.
_prefix
=
cert
.
prefix
self
.
_prefix
=
cert
.
prefix
...
@@ -242,14 +250,14 @@ class BaseTunnelManager(object):
...
@@ -242,14 +250,14 @@ class BaseTunnelManager(object):
self
.
_country
=
{}
self
.
_country
=
{}
address_dict
=
{
family
:
self
.
_updateCountry
(
address
)
address_dict
=
{
family
:
self
.
_updateCountry
(
address
)
for
family
,
address
in
address_dict
.
ite
rite
ms
()}
for
family
,
address
in
address_dict
.
items
()}
elif
cache
.
same_country
:
elif
cache
.
same_country
:
sys
.
exit
(
"Can not respect 'same_country' network configuration"
sys
.
exit
(
"Can not respect 'same_country' network configuration"
" (GEOIP2_MMDB not set)"
)
" (GEOIP2_MMDB not set)"
)
self
.
_address
=
{
family
:
utils
.
dump_address
(
address
)
self
.
_address
=
{
family
:
utils
.
dump_address
(
address
)
for
family
,
address
in
address_dict
.
ite
rite
ms
()
for
family
,
address
in
address_dict
.
items
()
if
address
}
if
address
}
cache
.
my_address
=
';'
.
join
(
self
.
_address
.
iter
values
())
cache
.
my_address
=
';'
.
join
(
self
.
_address
.
values
())
self
.
sock
=
socket
.
socket
(
socket
.
AF_INET6
,
self
.
sock
=
socket
.
socket
(
socket
.
AF_INET6
,
socket
.
SOCK_DGRAM
|
socket
.
SOCK_CLOEXEC
)
socket
.
SOCK_DGRAM
|
socket
.
SOCK_CLOEXEC
)
...
@@ -329,7 +337,7 @@ class BaseTunnelManager(object):
...
@@ -329,7 +337,7 @@ class BaseTunnelManager(object):
def
_getPeer
(
self
,
prefix
):
def
_getPeer
(
self
,
prefix
):
return
self
.
_peers
[
bisect
(
self
.
_peers
,
prefix
)
-
1
]
return
self
.
_peers
[
bisect
(
self
.
_peers
,
prefix
)
-
1
]
def
sendto
(
self
,
prefix
,
msg
):
def
sendto
(
self
,
prefix
:
str
,
msg
):
to
=
utils
.
ipFromBin
(
self
.
_network
+
prefix
),
PORT
to
=
utils
.
ipFromBin
(
self
.
_network
+
prefix
),
PORT
peer
=
self
.
_getPeer
(
prefix
)
peer
=
self
.
_getPeer
(
prefix
)
if
peer
.
prefix
!=
prefix
:
if
peer
.
prefix
!=
prefix
:
...
@@ -344,9 +352,11 @@ class BaseTunnelManager(object):
...
@@ -344,9 +352,11 @@ class BaseTunnelManager(object):
peer
.
hello0Sent
()
peer
.
hello0Sent
()
def
_sendto
(
self
,
to
,
msg
,
peer
=
None
):
def
_sendto
(
self
,
to
,
msg
,
peer
=
None
):
if
type
(
msg
)
is
str
:
msg
=
msg
.
encode
()
try
:
try
:
r
=
self
.
sock
.
sendto
(
peer
.
encode
(
msg
)
if
peer
else
msg
,
to
)
r
=
self
.
sock
.
sendto
(
peer
.
encode
(
msg
)
if
peer
else
msg
,
to
)
except
socket
.
error
,
e
:
except
socket
.
error
as
e
:
(
logging
.
info
if
e
.
errno
==
errno
.
ENETUNREACH
else
logging
.
error
)(
(
logging
.
info
if
e
.
errno
==
errno
.
ENETUNREACH
else
logging
.
error
)(
'Failed to send message to %s (%s)'
,
to
,
e
)
'Failed to send message to %s (%s)'
,
to
,
e
)
return
return
...
@@ -359,19 +369,20 @@ class BaseTunnelManager(object):
...
@@ -359,19 +369,20 @@ class BaseTunnelManager(object):
to
=
address
[:
2
]
to
=
address
[:
2
]
if
address
[
0
]
==
'::1'
:
if
address
[
0
]
==
'::1'
:
try
:
try
:
prefix
,
msg
=
msg
.
split
(
'
\
0
'
,
1
)
prefix
,
msg
=
msg
.
split
(
b'
\
0
'
,
1
)
prefix
=
prefix
.
decode
()
int
(
prefix
,
2
)
int
(
prefix
,
2
)
except
ValueError
:
except
ValueError
:
return
return
if
msg
:
if
msg
:
self
.
_forward
=
to
self
.
_forward
=
to
code
=
ord
(
msg
[
0
])
code
=
msg
[
0
]
if
prefix
==
self
.
_prefix
:
if
prefix
==
self
.
_prefix
:
msg
=
self
.
_processPacket
(
msg
)
msg
=
self
.
_processPacket
(
msg
)
if
msg
:
if
msg
:
self
.
_sendto
(
to
,
'%s
\
0
%c%s'
%
(
prefix
,
code
,
msg
))
self
.
_sendto
(
to
,
'%s
\
0
%c%s'
%
(
prefix
,
code
,
msg
))
else
:
else
:
self
.
sendto
(
prefix
,
chr
(
code
|
0x80
)
+
msg
[
1
:])
self
.
sendto
(
prefix
,
bytes
([
code
|
0x80
]
)
+
msg
[
1
:])
return
return
try
:
try
:
sender
=
utils
.
binFromIp
(
address
[
0
])
sender
=
utils
.
binFromIp
(
address
[
0
])
...
@@ -384,7 +395,7 @@ class BaseTunnelManager(object):
...
@@ -384,7 +395,7 @@ class BaseTunnelManager(object):
msg
=
peer
.
decode
(
msg
)
msg
=
peer
.
decode
(
msg
)
if
type
(
msg
)
is
tuple
:
if
type
(
msg
)
is
tuple
:
seqno
,
msg
,
protocol
=
msg
seqno
,
msg
,
protocol
=
msg
def
handleHello
(
peer
,
seqno
,
msg
,
retry
):
def
handleHello
(
peer
,
seqno
,
msg
:
bytes
,
retry
):
if
seqno
==
2
:
if
seqno
==
2
:
i
=
len
(
msg
)
//
2
i
=
len
(
msg
)
//
2
h
=
msg
[:
i
]
h
=
msg
[:
i
]
...
@@ -394,10 +405,10 @@ class BaseTunnelManager(object):
...
@@ -394,10 +405,10 @@ class BaseTunnelManager(object):
except
(
AttributeError
,
crypto
.
Error
,
x509
.
NewSessionError
,
except
(
AttributeError
,
crypto
.
Error
,
x509
.
NewSessionError
,
subprocess
.
CalledProcessError
):
subprocess
.
CalledProcessError
):
logging
.
debug
(
'ignored new session key from %r'
,
logging
.
debug
(
'ignored new session key from %r'
,
address
,
exc_info
=
1
)
address
,
exc_info
=
True
)
return
return
peer
.
version
=
self
.
_version
\
peer
.
version
=
self
.
_version
\
if
self
.
_sendto
(
to
,
'
\
0
'
+
self
.
_version
,
peer
)
else
''
if
self
.
_sendto
(
to
,
b'
\
0
'
+
self
.
_version
,
peer
)
else
b
''
return
return
if
seqno
:
if
seqno
:
h
=
x509
.
fingerprint
(
self
.
cert
.
cert
).
digest
()
h
=
x509
.
fingerprint
(
self
.
cert
.
cert
).
digest
()
...
@@ -410,7 +421,7 @@ class BaseTunnelManager(object):
...
@@ -410,7 +421,7 @@ class BaseTunnelManager(object):
serial
=
cert
.
get_serial_number
()
serial
=
cert
.
get_serial_number
()
if
serial
in
self
.
cache
.
crl
:
if
serial
in
self
.
cache
.
crl
:
raise
ValueError
(
"revoked"
)
raise
ValueError
(
"revoked"
)
except
(
x509
.
VerifyError
,
ValueError
)
,
e
:
except
(
x509
.
VerifyError
,
ValueError
)
as
e
:
if
retry
:
if
retry
:
return
True
return
True
logging
.
debug
(
'ignored invalid certificate from %r (%s)'
,
logging
.
debug
(
'ignored invalid certificate from %r (%s)'
,
...
@@ -444,10 +455,12 @@ class BaseTunnelManager(object):
...
@@ -444,10 +455,12 @@ class BaseTunnelManager(object):
# We got a valid and non-empty message. Always reply
# We got a valid and non-empty message. Always reply
# something so that the sender knows we're still connected.
# something so that the sender knows we're still connected.
answer
=
self
.
_processPacket
(
msg
,
peer
.
prefix
)
answer
=
self
.
_processPacket
(
msg
,
peer
.
prefix
)
self
.
_sendto
(
to
,
msg
[
0
]
+
answer
if
answer
else
""
,
peer
)
self
.
_sendto
(
to
,
msg
[
0
:
1
]
+
answer
.
encode
()
if
answer
else
b''
,
peer
)
def
_processPacket
(
self
,
msg
,
pee
r
=
None
):
def
_processPacket
(
self
,
msg
:
bytes
,
peer
:
x509
.
Peer
|
st
r
=
None
):
c
=
ord
(
msg
[
0
])
c
=
msg
[
0
]
msg
=
msg
[
1
:]
msg
=
msg
[
1
:]
code
=
c
&
0x7f
code
=
c
&
0x7f
if
c
>
0x7f
and
msg
:
if
c
>
0x7f
and
msg
:
...
@@ -456,6 +469,7 @@ class BaseTunnelManager(object):
...
@@ -456,6 +469,7 @@ class BaseTunnelManager(object):
elif
code
==
1
:
# address
elif
code
==
1
:
# address
if
msg
:
if
msg
:
if
peer
:
if
peer
:
msg
=
msg
.
decode
()
self
.
cache
.
addPeer
(
peer
,
msg
)
self
.
cache
.
addPeer
(
peer
,
msg
)
try
:
try
:
self
.
_connecting
.
remove
(
peer
)
self
.
_connecting
.
remove
(
peer
)
...
@@ -467,8 +481,8 @@ class BaseTunnelManager(object):
...
@@ -467,8 +481,8 @@ class BaseTunnelManager(object):
# Don't send country to old nodes
# Don't send country to old nodes
if
self
.
_getPeer
(
peer
).
protocol
<
7
:
if
self
.
_getPeer
(
peer
).
protocol
<
7
:
return
';'
.
join
(
','
.
join
(
a
.
split
(
','
)[:
3
])
for
a
in
return
';'
.
join
(
','
.
join
(
a
.
split
(
','
)[:
3
])
for
a
in
';'
.
join
(
self
.
_address
.
iter
values
()).
split
(
';'
))
';'
.
join
(
self
.
_address
.
values
()).
split
(
';'
))
return
';'
.
join
(
self
.
_address
.
iter
values
())
return
';'
.
join
(
self
.
_address
.
values
())
elif
not
code
:
# network version
elif
not
code
:
# network version
if
peer
:
if
peer
:
try
:
try
:
...
@@ -526,7 +540,7 @@ class BaseTunnelManager(object):
...
@@ -526,7 +540,7 @@ class BaseTunnelManager(object):
if
peer
.
prefix
!=
prefix
:
if
peer
.
prefix
!=
prefix
:
self
.
sendto
(
prefix
,
None
)
self
.
sendto
(
prefix
,
None
)
elif
(
peer
.
version
<
self
.
_version
and
elif
(
peer
.
version
<
self
.
_version
and
self
.
sendto
(
prefix
,
'
\
0
'
+
self
.
_version
)):
self
.
sendto
(
prefix
,
b
'
\
0
'
+
self
.
_version
)):
peer
.
version
=
self
.
_version
peer
.
version
=
self
.
_version
def
broadcastNewVersion
(
self
):
def
broadcastNewVersion
(
self
):
...
@@ -553,18 +567,18 @@ class BaseTunnelManager(object):
...
@@ -553,18 +567,18 @@ class BaseTunnelManager(object):
if
(
not
self
.
NEED_RESTART
.
isdisjoint
(
changed
)
if
(
not
self
.
NEED_RESTART
.
isdisjoint
(
changed
)
or
version
.
protocol
<
self
.
cache
.
min_protocol
or
version
.
protocol
<
self
.
cache
.
min_protocol
# TODO: With --management, we could kill clients without restarting.
# TODO: With --management, we could kill clients without restarting.
or
not
all
(
crl
.
isdisjoint
(
serials
.
iter
values
())
or
not
all
(
crl
.
isdisjoint
(
serials
.
values
())
for
serials
in
self
.
_served
.
iter
values
())):
for
serials
in
self
.
_served
.
values
())):
# Wait at least 1 second to broadcast new version to neighbours.
# Wait at least 1 second to broadcast new version to neighbours.
self
.
selectTimeout
(
time
.
time
()
+
1
+
self
.
cache
.
delay_restart
,
self
.
selectTimeout
(
time
.
time
()
+
1
+
self
.
cache
.
delay_restart
,
self
.
_restart
)
self
.
_restart
)
def
handleServerEvent
(
self
,
sock
):
def
handleServerEvent
(
self
,
sock
:
socket
.
socket
):
event
,
args
=
eval
(
sock
.
recv
(
65536
))
event
,
args
=
eval
(
sock
.
recv
(
65536
))
logging
.
debug
(
"%s%r"
,
event
,
args
)
logging
.
debug
(
"%s%r"
,
event
,
args
)
r
=
getattr
(
self
,
'_ovpn_'
+
event
.
replace
(
'-'
,
'_'
))(
*
args
)
r
=
getattr
(
self
,
'_ovpn_'
+
event
.
replace
(
'-'
,
'_'
))(
*
args
)
if
r
is
not
None
:
if
r
is
not
None
:
sock
.
send
(
chr
(
r
))
sock
.
send
(
bytes
([
r
]
))
def
_ovpn_client_connect
(
self
,
common_name
,
iface
,
serial
,
trusted_ip
):
def
_ovpn_client_connect
(
self
,
common_name
,
iface
,
serial
,
trusted_ip
):
if
serial
in
self
.
cache
.
crl
:
if
serial
in
self
.
cache
.
crl
:
...
@@ -576,7 +590,7 @@ class BaseTunnelManager(object):
...
@@ -576,7 +590,7 @@ class BaseTunnelManager(object):
self
.
_gateway_manager
.
add
(
trusted_ip
,
False
)
self
.
_gateway_manager
.
add
(
trusted_ip
,
False
)
if
prefix
in
self
.
_connection_dict
and
self
.
_prefix
<
prefix
:
if
prefix
in
self
.
_connection_dict
and
self
.
_prefix
<
prefix
:
self
.
_kill
(
prefix
)
self
.
_kill
(
prefix
)
self
.
cache
.
connecting
(
prefix
,
0
)
self
.
cache
.
connecting
(
prefix
,
False
)
return
True
return
True
def
_ovpn_client_disconnect
(
self
,
common_name
,
iface
,
serial
,
trusted_ip
):
def
_ovpn_client_disconnect
(
self
,
common_name
,
iface
,
serial
,
trusted_ip
):
...
@@ -606,7 +620,7 @@ class BaseTunnelManager(object):
...
@@ -606,7 +620,7 @@ class BaseTunnelManager(object):
with
open
(
'/proc/net/ipv6_route'
,
"r"
,
4096
)
as
f
:
with
open
(
'/proc/net/ipv6_route'
,
"r"
,
4096
)
as
f
:
try
:
try
:
routing_table
=
f
.
read
()
routing_table
=
f
.
read
()
except
IOError
,
e
:
except
IOError
as
e
:
# ???: If someone can explain why the kernel sometimes fails
# ???: If someone can explain why the kernel sometimes fails
# even when there's a lot of free memory.
# even when there's a lot of free memory.
if
e
.
errno
!=
errno
.
ENOMEM
:
if
e
.
errno
!=
errno
.
ENOMEM
:
...
@@ -635,7 +649,7 @@ class BaseTunnelManager(object):
...
@@ -635,7 +649,7 @@ class BaseTunnelManager(object):
logging
.
error
(
"%s. Flushing..."
,
msg
)
logging
.
error
(
"%s. Flushing..."
,
msg
)
subprocess
.
call
((
"ip"
,
"-6"
,
"route"
,
"flush"
,
"cached"
))
subprocess
.
call
((
"ip"
,
"-6"
,
"route"
,
"flush"
,
"cached"
))
self
.
sendto
(
self
.
cache
.
registry_prefix
,
self
.
sendto
(
self
.
cache
.
registry_prefix
,
'
\
7
%s (%s)'
%
(
msg
,
os
.
uname
()[
2
]))
b
'
\
7
%s (%s)'
%
(
msg
,
os
.
uname
()[
2
]))
break
break
def
_updateCountry
(
self
,
address
):
def
_updateCountry
(
self
,
address
):
...
@@ -660,7 +674,8 @@ class TunnelManager(BaseTunnelManager):
...
@@ -660,7 +674,8 @@ class TunnelManager(BaseTunnelManager):
def
__init__
(
self
,
control_socket
,
cache
,
cert
,
openvpn_args
,
def
__init__
(
self
,
control_socket
,
cache
,
cert
,
openvpn_args
,
timeout
,
client_count
,
iface_list
,
conf_country
,
address
,
timeout
,
client_count
,
iface_list
,
conf_country
,
address
,
ip_changed
,
remote_gateway
,
disable_proto
,
neighbour_list
=
()):
ip_changed
,
remote_gateway
:
Callable
[[
str
],
str
],
disable_proto
:
Sequence
[
str
],
neighbour_list
=
()):
super
(
TunnelManager
,
self
).
__init__
(
control_socket
,
super
(
TunnelManager
,
self
).
__init__
(
control_socket
,
cache
,
cert
,
conf_country
,
address
)
cache
,
cert
,
conf_country
,
address
)
self
.
ovpn_args
=
openvpn_args
self
.
ovpn_args
=
openvpn_args
...
@@ -683,7 +698,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -683,7 +698,7 @@ class TunnelManager(BaseTunnelManager):
self
.
_client_count
=
client_count
self
.
_client_count
=
client_count
self
.
new_iface_list
=
deque
(
're6stnet'
+
str
(
i
)
self
.
new_iface_list
=
deque
(
're6stnet'
+
str
(
i
)
for
i
in
x
range
(
1
,
self
.
_client_count
+
1
))
for
i
in
range
(
1
,
self
.
_client_count
+
1
))
self
.
_free_iface_list
=
[]
self
.
_free_iface_list
=
[]
def
close
(
self
):
def
close
(
self
):
...
@@ -752,7 +767,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -752,7 +767,7 @@ class TunnelManager(BaseTunnelManager):
def
babel_dump
(
self
):
def
babel_dump
(
self
):
t
=
time
.
time
()
t
=
time
.
time
()
if
self
.
_killing
:
if
self
.
_killing
:
for
prefix
,
tunnel_killer
in
self
.
_killing
.
items
(
):
for
prefix
,
tunnel_killer
in
list
(
self
.
_killing
.
items
()
):
if
tunnel_killer
.
timeout
<
t
:
if
tunnel_killer
.
timeout
<
t
:
if
tunnel_killer
.
state
!=
'unlocking'
:
if
tunnel_killer
.
state
!=
'unlocking'
:
logging
.
info
(
logging
.
info
(
...
@@ -780,7 +795,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -780,7 +795,7 @@ class TunnelManager(BaseTunnelManager):
def
_cleanDeads
(
self
):
def
_cleanDeads
(
self
):
disconnected
=
False
disconnected
=
False
for
prefix
in
self
.
_connection_dict
.
keys
(
):
for
prefix
in
list
(
self
.
_connection_dict
):
status
=
self
.
_connection_dict
[
prefix
].
refresh
()
status
=
self
.
_connection_dict
[
prefix
].
refresh
()
if
status
:
if
status
:
disconnected
|=
status
>
0
disconnected
|=
status
>
0
...
@@ -872,7 +887,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -872,7 +887,7 @@ class TunnelManager(BaseTunnelManager):
address_list
.
append
((
ip
,
x
[
1
],
x
[
2
]))
address_list
.
append
((
ip
,
x
[
1
],
x
[
2
]))
continue
continue
address_list
.
append
(
x
[:
3
])
address_list
.
append
(
x
[:
3
])
self
.
cache
.
connecting
(
prefix
,
1
)
self
.
cache
.
connecting
(
prefix
,
True
)
if
not
address_list
:
if
not
address_list
:
return
False
return
False
logging
.
info
(
'Establishing a connection with %u/%u'
,
logging
.
info
(
'Establishing a connection with %u/%u'
,
...
@@ -902,7 +917,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -902,7 +917,7 @@ class TunnelManager(BaseTunnelManager):
neighbours
=
self
.
ctl
.
neighbours
neighbours
=
self
.
ctl
.
neighbours
# Collect all nodes known by Babel
# Collect all nodes known by Babel
peers
=
{
prefix
peers
=
{
prefix
for
neigh_routes
in
neighbours
.
iter
values
()
for
neigh_routes
in
neighbours
.
values
()
for
prefix
in
neigh_routes
[
1
]
for
prefix
in
neigh_routes
[
1
]
if
prefix
}
if
prefix
}
# Keep only distant peers.
# Keep only distant peers.
...
@@ -957,7 +972,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -957,7 +972,7 @@ class TunnelManager(BaseTunnelManager):
address
=
self
.
cache
.
getAddress
(
peer
)
address
=
self
.
cache
.
getAddress
(
peer
)
if
address
:
if
address
:
count
-=
self
.
_makeTunnel
(
peer
,
address
)
count
-=
self
.
_makeTunnel
(
peer
,
address
)
elif
self
.
sendto
(
peer
,
'
\
1
'
):
elif
self
.
sendto
(
peer
,
b
'
\
1
'
):
self
.
_connecting
.
add
(
peer
)
self
.
_connecting
.
add
(
peer
)
count
-=
1
count
-=
1
elif
distant_peers
is
None
:
elif
distant_peers
is
None
:
...
@@ -987,7 +1002,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -987,7 +1002,7 @@ class TunnelManager(BaseTunnelManager):
break
break
def
killAll
(
self
):
def
killAll
(
self
):
for
prefix
in
self
.
_connection_dict
.
keys
(
):
for
prefix
in
list
(
self
.
_connection_dict
):
self
.
_kill
(
prefix
)
self
.
_kill
(
prefix
)
def
handleClientEvent
(
self
):
def
handleClientEvent
(
self
):
...
@@ -999,7 +1014,7 @@ class TunnelManager(BaseTunnelManager):
...
@@ -999,7 +1014,7 @@ class TunnelManager(BaseTunnelManager):
if
c
and
c
.
time
<
float
(
time
):
if
c
and
c
.
time
<
float
(
time
):
try
:
try
:
c
.
connected
(
serial
)
c
.
connected
(
serial
)
except
(
KeyError
,
TypeError
)
,
e
:
except
(
KeyError
,
TypeError
)
as
e
:
logging
.
error
(
"%s (route_up %s)"
,
e
,
common_name
)
logging
.
error
(
"%s (route_up %s)"
,
e
,
common_name
)
else
:
else
:
logging
.
info
(
"ignore route_up notification for %s %r"
,
logging
.
info
(
"ignore route_up notification for %s %r"
,
...
@@ -1010,10 +1025,10 @@ class TunnelManager(BaseTunnelManager):
...
@@ -1010,10 +1025,10 @@ class TunnelManager(BaseTunnelManager):
if
self
.
cache
.
same_country
:
if
self
.
cache
.
same_country
:
address
=
self
.
_updateCountry
(
address
)
address
=
self
.
_updateCountry
(
address
)
self
.
_address
[
family
]
=
utils
.
dump_address
(
address
)
self
.
_address
[
family
]
=
utils
.
dump_address
(
address
)
self
.
cache
.
my_address
=
';'
.
join
(
self
.
_address
.
iter
values
())
self
.
cache
.
my_address
=
';'
.
join
(
self
.
_address
.
values
())
def
broadcastNewVersion
(
self
):
def
broadcastNewVersion
(
self
):
self
.
_babel_dump_new_version
()
self
.
_babel_dump_new_version
()
for
prefix
,
c
in
self
.
_connection_dict
.
items
(
):
for
prefix
,
c
in
list
(
self
.
_connection_dict
.
items
()
):
if
c
.
serial
in
self
.
cache
.
crl
:
if
c
.
serial
in
self
.
cache
.
crl
:
self
.
_kill
(
prefix
)
self
.
_kill
(
prefix
)
re6st/upnpigd.py
View file @
e80ca878
...
@@ -7,7 +7,7 @@ class UPnPException(Exception):
...
@@ -7,7 +7,7 @@ class UPnPException(Exception):
pass
pass
class
Forwarder
(
object
)
:
class
Forwarder
:
"""
"""
External port is chosen randomly between 32768 & 49151 included.
External port is chosen randomly between 32768 & 49151 included.
"""
"""
...
@@ -17,7 +17,7 @@ class Forwarder(object):
...
@@ -17,7 +17,7 @@ class Forwarder(object):
_lcg_n
=
0
_lcg_n
=
0
@
classmethod
@
classmethod
def
_getExternalPort
(
cls
):
def
_getExternalPort
(
cls
)
->
int
:
# Since _refresh() does not test all ports in a row, we prefer to
# Since _refresh() does not test all ports in a row, we prefer to
# return random ports to maximize the chance to find a free port.
# return random ports to maximize the chance to find a free port.
# A linear congruential generator should be random enough, without
# A linear congruential generator should be random enough, without
...
@@ -35,12 +35,12 @@ class Forwarder(object):
...
@@ -35,12 +35,12 @@ class Forwarder(object):
self
.
_u
.
discoverdelay
=
200
self
.
_u
.
discoverdelay
=
200
self
.
_rules
=
[]
self
.
_rules
=
[]
def
__getattr__
(
self
,
name
):
def
__getattr__
(
self
,
name
:
str
):
wrapped
=
getattr
(
self
.
_u
,
name
)
wrapped
=
getattr
(
self
.
_u
,
name
)
def
wrapper
(
*
args
,
**
kw
):
def
wrapper
(
*
args
,
**
kw
):
try
:
try
:
return
wrapped
(
*
args
,
**
kw
)
return
wrapped
(
*
args
,
**
kw
)
except
Exception
,
e
:
except
Exception
as
e
:
raise
UPnPException
(
str
(
e
))
raise
UPnPException
(
str
(
e
))
return
wraps
(
wrapped
)(
wrapper
)
return
wraps
(
wrapped
)(
wrapper
)
...
@@ -68,14 +68,14 @@ class Forwarder(object):
...
@@ -68,14 +68,14 @@ class Forwarder(object):
else
:
else
:
try
:
try
:
return
self
.
_refresh
()
return
self
.
_refresh
()
except
UPnPException
,
e
:
except
UPnPException
as
e
:
logging
.
debug
(
"UPnP failure"
,
exc_info
=
1
)
logging
.
debug
(
"UPnP failure"
,
exc_info
=
True
)
self
.
clear
()
self
.
clear
()
try
:
try
:
self
.
discover
()
self
.
discover
()
self
.
selectigd
()
self
.
selectigd
()
return
self
.
_refresh
()
return
self
.
_refresh
()
except
UPnPException
,
e
:
except
UPnPException
as
e
:
self
.
next_refresh
=
self
.
_next_retry
=
time
.
time
()
+
60
self
.
next_refresh
=
self
.
_next_retry
=
time
.
time
()
+
60
logging
.
info
(
str
(
e
))
logging
.
info
(
str
(
e
))
self
.
clear
()
self
.
clear
()
...
@@ -109,7 +109,7 @@ class Forwarder(object):
...
@@ -109,7 +109,7 @@ class Forwarder(object):
try
:
try
:
self
.
addportmapping
(
port
,
*
args
)
self
.
addportmapping
(
port
,
*
args
)
break
break
except
UPnPException
,
e
:
except
UPnPException
as
e
:
if
str
(
e
)
!=
'ConflictInMappingEntry'
:
if
str
(
e
)
!=
'ConflictInMappingEntry'
:
raise
raise
port
=
None
port
=
None
...
...
re6st/utils.py
View file @
e80ca878
import
argparse
,
errno
,
fcntl
,
hashlib
,
logging
,
os
,
select
as
_select
import
argparse
,
errno
,
fcntl
,
hashlib
,
logging
,
os
,
select
as
_select
import
shlex
,
signal
,
socket
,
sqlite3
,
struct
,
subprocess
import
shlex
,
signal
,
socket
,
sqlite3
,
struct
,
subprocess
import
sys
,
textwrap
,
threading
,
time
,
traceback
import
sys
,
textwrap
,
threading
,
time
,
traceback
from
collections.abc
import
Iterator
,
Mapping
# PY3: It will be even better to use Popen(pass_fds=...),
HMAC_LEN
=
len
(
hashlib
.
sha1
(
b''
).
digest
())
# and then socket.SOCK_CLOEXEC will be useless.
# (We already follow the good practice that consists in not
# relying on the GC for the closing of file descriptors.)
socket
.
SOCK_CLOEXEC
=
0x80000
HMAC_LEN
=
len
(
hashlib
.
sha1
(
''
).
digest
())
class
ReexecException
(
Exception
):
class
ReexecException
(
Exception
):
pass
pass
...
@@ -37,15 +32,15 @@ class FileHandler(logging.FileHandler):
...
@@ -37,15 +32,15 @@ class FileHandler(logging.FileHandler):
finally
:
finally
:
self
.
lock
.
release
()
self
.
lock
.
release
()
# In the rare case _reopen is set just before the lock was released
# In the rare case _reopen is set just before the lock was released
if
self
.
_reopen
and
self
.
lock
.
acquire
(
0
):
if
self
.
_reopen
and
self
.
lock
.
acquire
(
False
):
self
.
release
()
self
.
release
()
def
async_reopen
(
self
,
*
_
):
def
async_reopen
(
self
,
*
_
):
self
.
_reopen
=
True
self
.
_reopen
=
True
if
self
.
lock
.
acquire
(
0
):
if
self
.
lock
.
acquire
(
False
):
self
.
release
()
self
.
release
()
def
setupLog
(
log_level
,
filenam
e
=
None
,
**
kw
):
def
setupLog
(
log_level
:
int
,
filename
:
str
|
Non
e
=
None
,
**
kw
):
if
log_level
and
filename
:
if
log_level
and
filename
:
makedirs
(
os
.
path
.
dirname
(
filename
))
makedirs
(
os
.
path
.
dirname
(
filename
))
handler
=
FileHandler
(
filename
)
handler
=
FileHandler
(
filename
)
...
@@ -119,7 +114,7 @@ class ArgParser(argparse.ArgumentParser):
...
@@ -119,7 +114,7 @@ class ArgParser(argparse.ArgumentParser):
ca /etc/re6stnet/ca.crt"""
,
**
kw
)
ca /etc/re6stnet/ca.crt"""
,
**
kw
)
class
exit
(
object
)
:
class
exit
:
status
=
None
status
=
None
...
@@ -150,7 +145,7 @@ class exit(object):
...
@@ -150,7 +145,7 @@ class exit(object):
def
handler
(
*
args
):
def
handler
(
*
args
):
if
self
.
status
is
None
:
if
self
.
status
is
None
:
self
.
status
=
status
self
.
status
=
status
if
self
.
acquire
(
0
):
if
self
.
acquire
(
False
):
self
.
release
()
self
.
release
()
for
sig
in
sigs
:
for
sig
in
sigs
:
signal
.
signal
(
sig
,
handler
)
signal
.
signal
(
sig
,
handler
)
...
@@ -164,7 +159,7 @@ class Popen(subprocess.Popen):
...
@@ -164,7 +159,7 @@ class Popen(subprocess.Popen):
self
.
_args
=
tuple
(
args
[
0
]
if
args
else
kw
[
'args'
])
self
.
_args
=
tuple
(
args
[
0
]
if
args
else
kw
[
'args'
])
try
:
try
:
super
(
Popen
,
self
).
__init__
(
*
args
,
**
kw
)
super
(
Popen
,
self
).
__init__
(
*
args
,
**
kw
)
except
OSError
,
e
:
except
OSError
as
e
:
if
e
.
errno
!=
errno
.
ENOMEM
:
if
e
.
errno
!=
errno
.
ENOMEM
:
raise
raise
self
.
returncode
=
-
1
self
.
returncode
=
-
1
...
@@ -179,9 +174,9 @@ class Popen(subprocess.Popen):
...
@@ -179,9 +174,9 @@ class Popen(subprocess.Popen):
self
.
terminate
()
self
.
terminate
()
t
=
threading
.
Timer
(
5
,
self
.
kill
)
t
=
threading
.
Timer
(
5
,
self
.
kill
)
t
.
start
()
t
.
start
()
# PY3: use waitid(WNOWAIT) and call self.poll() after t.cancel()
r
=
os
.
waitid
(
os
.
P_PID
,
self
.
pid
,
os
.
WNOWAIT
)
r
=
self
.
wait
()
t
.
cancel
()
t
.
cancel
()
self
.
poll
()
return
r
return
r
...
@@ -189,7 +184,7 @@ def setCloexec(fd):
...
@@ -189,7 +184,7 @@ def setCloexec(fd):
flags
=
fcntl
.
fcntl
(
fd
,
fcntl
.
F_GETFD
)
flags
=
fcntl
.
fcntl
(
fd
,
fcntl
.
F_GETFD
)
fcntl
.
fcntl
(
fd
,
fcntl
.
F_SETFD
,
flags
|
fcntl
.
FD_CLOEXEC
)
fcntl
.
fcntl
(
fd
,
fcntl
.
F_SETFD
,
flags
|
fcntl
.
FD_CLOEXEC
)
def
select
(
R
,
W
,
T
):
def
select
(
R
:
Mapping
,
W
:
Mapping
,
T
):
try
:
try
:
r
,
w
,
_
=
_select
.
select
(
R
,
W
,
(),
r
,
w
,
_
=
_select
.
select
(
R
,
W
,
(),
max
(
0
,
min
(
T
)[
0
]
-
time
.
time
())
if
T
else
None
)
max
(
0
,
min
(
T
)[
0
]
-
time
.
time
())
if
T
else
None
)
...
@@ -209,19 +204,19 @@ def select(R, W, T):
...
@@ -209,19 +204,19 @@ def select(R, W, T):
def
makedirs
(
*
args
):
def
makedirs
(
*
args
):
try
:
try
:
os
.
makedirs
(
*
args
)
os
.
makedirs
(
*
args
)
except
OSError
,
e
:
except
OSError
as
e
:
if
e
.
errno
!=
errno
.
EEXIST
:
if
e
.
errno
!=
errno
.
EEXIST
:
raise
raise
def
binFromIp
(
ip
)
:
def
binFromIp
(
ip
:
str
)
->
str
:
return
binFromRawIp
(
socket
.
inet_pton
(
socket
.
AF_INET6
,
ip
))
return
binFromRawIp
(
socket
.
inet_pton
(
socket
.
AF_INET6
,
ip
))
def
binFromRawIp
(
ip
)
:
def
binFromRawIp
(
ip
:
bytes
)
->
str
:
ip1
,
ip2
=
struct
.
unpack
(
'>QQ'
,
ip
)
ip1
,
ip2
=
struct
.
unpack
(
'>QQ'
,
ip
)
return
bin
(
ip1
)[
2
:].
rjust
(
64
,
'0'
)
+
bin
(
ip2
)[
2
:].
rjust
(
64
,
'0'
)
return
bin
(
ip1
)[
2
:].
rjust
(
64
,
'0'
)
+
bin
(
ip2
)[
2
:].
rjust
(
64
,
'0'
)
def
ipFromBin
(
ip
,
suffix
=
''
)
:
def
ipFromBin
(
ip
:
str
,
suffix
=
''
)
->
str
:
suffix_len
=
128
-
len
(
ip
)
suffix_len
=
128
-
len
(
ip
)
if
suffix_len
>
0
:
if
suffix_len
>
0
:
ip
+=
suffix
.
rjust
(
suffix_len
,
'0'
)
ip
+=
suffix
.
rjust
(
suffix_len
,
'0'
)
...
@@ -230,30 +225,32 @@ def ipFromBin(ip, suffix=''):
...
@@ -230,30 +225,32 @@ def ipFromBin(ip, suffix=''):
return
socket
.
inet_ntop
(
socket
.
AF_INET6
,
return
socket
.
inet_ntop
(
socket
.
AF_INET6
,
struct
.
pack
(
'>QQ'
,
int
(
ip
[:
64
],
2
),
int
(
ip
[
64
:],
2
)))
struct
.
pack
(
'>QQ'
,
int
(
ip
[:
64
],
2
),
int
(
ip
[
64
:],
2
)))
def
dump_address
(
address
)
:
def
dump_address
(
address
:
str
)
->
str
:
return
';'
.
join
(
map
(
','
.
join
,
address
))
return
';'
.
join
(
map
(
','
.
join
,
address
))
# Yield ip, port, protocol, and country if it is in the address
# Yield ip, port, protocol, and country if it is in the address
def
parse_address
(
address_list
)
:
def
parse_address
(
address_list
:
str
)
->
Iterator
[
tuple
[
str
,
str
,
str
,
str
]]
:
for
address
in
address_list
.
split
(
';'
):
for
address
in
address_list
.
split
(
';'
):
try
:
try
:
a
=
address
.
split
(
','
)
a
=
address
.
split
(
','
)
int
(
a
[
1
])
# Check if port is an int
int
(
a
[
1
])
# Check if port is an int
yield
tuple
(
a
[:
4
])
yield
tuple
(
a
[:
4
])
except
ValueError
,
e
:
except
ValueError
as
e
:
logging
.
warning
(
"Failed to parse node address %r (%s)"
,
logging
.
warning
(
"Failed to parse node address %r (%s)"
,
address
,
e
)
address
,
e
)
def
binFromSubnet
(
subnet
)
:
def
binFromSubnet
(
subnet
:
str
)
->
str
:
p
,
l
=
subnet
.
split
(
'/'
)
p
,
l
=
subnet
.
split
(
'/'
)
return
bin
(
int
(
p
))[
2
:].
rjust
(
int
(
l
),
'0'
)
return
bin
(
int
(
p
))[
2
:].
rjust
(
int
(
l
),
'0'
)
def
newHmacSecret
():
def
_
newHmacSecret
():
from
random
import
getrandbits
as
g
from
random
import
getrandbits
as
g
pack
=
struct
.
Struct
(
">QQI"
).
pack
pack
=
struct
.
Struct
(
">QQI"
).
pack
assert
len
(
pack
(
0
,
0
,
0
))
==
HMAC_LEN
assert
len
(
pack
(
0
,
0
,
0
))
==
HMAC_LEN
# A closure is built to avoid rebuilding the `pack` function at each call.
return
lambda
x
=
None
:
pack
(
g
(
64
)
if
x
is
None
else
x
,
g
(
64
),
g
(
32
))
return
lambda
x
=
None
:
pack
(
g
(
64
)
if
x
is
None
else
x
,
g
(
64
),
g
(
32
))
newHmacSecret
=
newHmacSecret
()
newHmacSecret
=
_newHmacSecret
()
# https://github.com/python/mypy/issues/1174
### Integer serialization
### Integer serialization
# - supports values from 0 to 0x202020202020201f
# - supports values from 0 to 0x202020202020201f
...
@@ -261,21 +258,21 @@ newHmacSecret = newHmacSecret()
...
@@ -261,21 +258,21 @@ newHmacSecret = newHmacSecret()
# - there's always a unique way to encode a value
# - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes
# - the 3 first bits code the number of bytes
def
packInteger
(
i
)
:
def
packInteger
(
i
:
int
)
->
bytes
:
for
n
in
x
range
(
8
):
for
n
in
range
(
8
):
x
=
32
<<
8
*
n
x
=
32
<<
8
*
n
if
i
<
x
:
if
i
<
x
:
return
struct
.
pack
(
"!Q"
,
i
+
n
*
x
)[
7
-
n
:]
return
struct
.
pack
(
"!Q"
,
i
+
n
*
x
)[
7
-
n
:]
i
-=
x
i
-=
x
raise
OverflowError
raise
OverflowError
def
unpackInteger
(
x
)
:
def
unpackInteger
(
x
:
bytes
)
->
tuple
[
int
,
int
]
|
None
:
n
=
ord
(
x
[
0
])
>>
5
n
=
x
[
0
]
>>
5
try
:
try
:
i
,
=
struct
.
unpack
(
"!Q"
,
'
\
0
'
*
(
7
-
n
)
+
x
[:
n
+
1
])
i
,
=
struct
.
unpack
(
"!Q"
,
b
'
\
0
'
*
(
7
-
n
)
+
x
[:
n
+
1
])
except
struct
.
error
:
except
struct
.
error
:
return
return
return
sum
((
32
<<
8
*
i
for
i
in
x
range
(
n
)),
return
sum
((
32
<<
8
*
i
for
i
in
range
(
n
)),
i
-
(
n
*
32
<<
8
*
n
)),
n
+
1
i
-
(
n
*
32
<<
8
*
n
)),
n
+
1
###
###
...
...
re6st/version.py
View file @
e80ca878
...
@@ -40,4 +40,4 @@ protocol = 8
...
@@ -40,4 +40,4 @@ protocol = 8
min_protocol
=
1
min_protocol
=
1
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
print
version
print
(
version
)
re6st/x509.py
View file @
e80ca878
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
import
calendar
,
hashlib
,
hmac
,
logging
,
os
,
struct
,
subprocess
,
threading
,
time
import
calendar
,
hashlib
,
hmac
,
logging
,
os
,
struct
,
subprocess
,
threading
,
time
from
typing
import
Callable
,
Any
from
OpenSSL
import
crypto
from
OpenSSL
import
crypto
from
cryptography.hazmat.primitives
import
hashes
from
cryptography.hazmat.primitives.asymmetric
import
padding
from
cryptography.hazmat.primitives.serialization
import
load_pem_private_key
from
cryptography.x509
import
load_pem_x509_certificate
from
.
import
utils
from
.
import
utils
from
.version
import
protocol
from
.version
import
protocol
def
newHmacSecret
():
def
newHmacSecret
()
->
bytes
:
return
utils
.
newHmacSecret
(
int
(
time
.
time
()
*
1000000
))
return
utils
.
newHmacSecret
(
int
(
time
.
time
()
*
1000000
))
def
networkFromCa
(
ca
):
def
networkFromCa
(
ca
:
crypto
.
X509
)
->
str
:
# TODO: will be ca.serial_number after migration to cryptography
return
bin
(
ca
.
get_serial_number
())[
3
:]
return
bin
(
ca
.
get_serial_number
())[
3
:]
def
subnetFromCert
(
cert
)
:
def
subnetFromCert
(
cert
:
crypto
.
X509
)
->
str
:
return
cert
.
get_subject
().
CN
return
cert
.
get_subject
().
CN
def
notBefore
(
cert
):
def
notBefore
(
cert
:
crypto
.
X509
)
->
int
:
return
calendar
.
timegm
(
time
.
strptime
(
cert
.
get_notBefore
(),
'%Y%m%d%H%M%SZ'
))
return
calendar
.
timegm
(
time
.
strptime
(
cert
.
get_notBefore
().
decode
(),
'%Y%m%d%H%M%SZ'
))
def
notAfter
(
cert
):
def
notAfter
(
cert
:
crypto
.
X509
)
->
int
:
return
calendar
.
timegm
(
time
.
strptime
(
cert
.
get_notAfter
(),
'%Y%m%d%H%M%SZ'
))
return
calendar
.
timegm
(
time
.
strptime
(
cert
.
get_notAfter
().
decode
(),
'%Y%m%d%H%M%SZ'
))
def
openssl
(
*
args
)
:
def
openssl
(
*
args
:
str
,
fds
=
[])
->
utils
.
Popen
:
return
utils
.
Popen
((
'openssl'
,)
+
args
,
return
utils
.
Popen
((
'openssl'
,)
+
args
,
stdin
=
subprocess
.
PIPE
,
stdin
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stdout
=
subprocess
.
PIPE
,
stderr
=
subprocess
.
PIPE
)
stderr
=
subprocess
.
PIPE
,
pass_fds
=
fds
)
def
encrypt
(
cert
,
data
)
:
def
encrypt
(
cert
:
bytes
,
data
:
bytes
)
->
bytes
:
r
,
w
=
os
.
pipe
()
r
,
w
=
os
.
pipe
()
try
:
try
:
threading
.
Thread
(
target
=
os
.
write
,
args
=
(
w
,
cert
)).
start
()
threading
.
Thread
(
target
=
os
.
write
,
args
=
(
w
,
cert
)).
start
()
p
=
openssl
(
'rsautl'
,
'-encrypt'
,
'-certin'
,
p
=
openssl
(
'rsautl'
,
'-encrypt'
,
'-certin'
,
'-inkey'
,
'/proc/self/fd/%u'
%
r
)
'-inkey'
,
'/proc/self/fd/%u'
%
r
,
fds
=
[
r
]
)
out
,
err
=
p
.
communicate
(
data
)
out
,
err
=
p
.
communicate
(
data
)
finally
:
finally
:
os
.
close
(
r
)
os
.
close
(
r
)
...
@@ -39,10 +49,12 @@ def encrypt(cert, data):
...
@@ -39,10 +49,12 @@ def encrypt(cert, data):
raise
subprocess
.
CalledProcessError
(
p
.
returncode
,
'openssl'
,
err
)
raise
subprocess
.
CalledProcessError
(
p
.
returncode
,
'openssl'
,
err
)
return
out
return
out
def
fingerprint
(
cert
,
alg
=
'sha1'
):
def
fingerprint
(
cert
:
crypto
.
X509
,
alg
=
'sha1'
):
return
hashlib
.
new
(
alg
,
crypto
.
dump_certificate
(
crypto
.
FILETYPE_ASN1
,
cert
))
return
hashlib
.
new
(
alg
,
crypto
.
dump_certificate
(
crypto
.
FILETYPE_ASN1
,
cert
))
def
maybe_renew
(
path
,
cert
,
info
,
renew
,
force
=
False
):
def
maybe_renew
(
path
:
str
,
cert
:
crypto
.
X509
,
info
:
str
,
renew
:
Callable
[[],
bytes
],
force
=
False
)
->
tuple
[
crypto
.
X509
,
int
]:
from
.registry
import
RENEW_PERIOD
from
.registry
import
RENEW_PERIOD
while
True
:
while
True
:
if
force
:
if
force
:
...
@@ -62,7 +74,7 @@ def maybe_renew(path, cert, info, renew, force=False):
...
@@ -62,7 +74,7 @@ def maybe_renew(path, cert, info, renew, force=False):
exc_info
=
1
exc_info
=
1
break
break
new_path
=
path
+
'.new'
new_path
=
path
+
'.new'
with
open
(
new_path
,
'w'
)
as
f
:
with
open
(
new_path
,
'w
b
'
)
as
f
:
f
.
write
(
pem
)
f
.
write
(
pem
)
try
:
try
:
s
=
os
.
stat
(
path
)
s
=
os
.
stat
(
path
)
...
@@ -84,39 +96,44 @@ class NewSessionError(Exception):
...
@@ -84,39 +96,44 @@ class NewSessionError(Exception):
pass
pass
class
Cert
(
object
)
:
class
Cert
:
def
__init__
(
self
,
ca
,
key
,
cert
=
None
):
def
__init__
(
self
,
ca
:
str
,
key
:
str
,
cert
:
str
|
None
=
None
):
self
.
ca_path
=
ca
self
.
ca_path
=
ca
self
.
cert_path
=
cert
self
.
cert_path
=
cert
self
.
key_path
=
key
self
.
key_path
=
key
with
open
(
ca
)
as
f
:
# TODO: finish migration from old OpenSSL module to cryptography
self
.
ca
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
f
.
read
())
with
open
(
ca
,
"rb"
)
as
f
:
with
open
(
key
)
as
f
:
ca_pem
=
f
.
read
()
self
.
key
=
crypto
.
load_privatekey
(
crypto
.
FILETYPE_PEM
,
f
.
read
())
self
.
ca
=
crypto
.
load_certificate
(
crypto
.
FILETYPE_PEM
,
ca_pem
)
self
.
ca_crypto
=
load_pem_x509_certificate
(
ca_pem
)
with
open
(
key
,
"rb"
)
as
f
:
key_pem
=
f
.
read
()
self
.
key
=
crypto
.
load_privatekey
(
crypto
.
FILETYPE_PEM
,
key_pem
)
self
.
key_crypto
=
load_pem_private_key
(
key_pem
,
password
=
None
)
if
cert
:
if
cert
:
with
open
(
cert
)
as
f
:
with
open
(
cert
)
as
f
:
self
.
cert
=
self
.
loadVerify
(
f
.
read
())
self
.
cert
=
self
.
loadVerify
(
f
.
read
()
.
encode
()
)
@
property
@
property
def
prefix
(
self
):
def
prefix
(
self
)
->
str
:
return
utils
.
binFromSubnet
(
subnetFromCert
(
self
.
cert
))
return
utils
.
binFromSubnet
(
subnetFromCert
(
self
.
cert
))
@
property
@
property
def
network
(
self
):
def
network
(
self
)
->
str
:
return
networkFromCa
(
self
.
ca
)
return
networkFromCa
(
self
.
ca
)
@
property
@
property
def
subject_serial
(
self
):
def
subject_serial
(
self
)
->
int
:
return
int
(
self
.
cert
.
get_subject
().
serialNumber
)
return
int
(
self
.
cert
.
get_subject
().
serialNumber
)
@
property
@
property
def
openvpn_args
(
self
):
def
openvpn_args
(
self
)
->
tuple
[
str
,
...]
:
return
(
'--ca'
,
self
.
ca_path
,
return
(
'--ca'
,
self
.
ca_path
,
'--cert'
,
self
.
cert_path
,
'--cert'
,
self
.
cert_path
,
'--key'
,
self
.
key_path
)
'--key'
,
self
.
key_path
)
def
maybeRenew
(
self
,
registry
,
crl
):
def
maybeRenew
(
self
,
registry
,
crl
)
->
int
:
self
.
cert
,
next_renew
=
maybe_renew
(
self
.
cert_path
,
self
.
cert
,
self
.
cert
,
next_renew
=
maybe_renew
(
self
.
cert_path
,
self
.
cert
,
"Certificate"
,
lambda
:
registry
.
renewCertificate
(
self
.
prefix
),
"Certificate"
,
lambda
:
registry
.
renewCertificate
(
self
.
prefix
),
self
.
cert
.
get_serial_number
()
in
crl
)
self
.
cert
.
get_serial_number
()
in
crl
)
...
@@ -143,21 +160,31 @@ class Cert(object):
...
@@ -143,21 +160,31 @@ class Cert(object):
"error running openssl, assuming cert is invalid"
)
"error running openssl, assuming cert is invalid"
)
# BBB: With old versions of openssl, detailed
# BBB: With old versions of openssl, detailed
# error is printed to standard output.
# error is printed to standard output.
for
err
in
err
,
out
:
for
stream
in
err
,
out
:
for
x
in
err
.
splitlines
():
for
x
in
stream
.
decode
(
errors
=
'replace'
)
.
splitlines
():
if
x
.
startswith
(
'error '
):
if
x
.
startswith
(
'error '
):
x
,
msg
=
x
.
split
(
':'
,
1
)
x
,
msg
=
x
.
split
(
':'
,
1
)
_
,
code
,
_
,
depth
,
_
=
x
.
split
(
None
,
4
)
_
,
code
,
_
,
depth
,
_
=
x
.
split
(
None
,
4
)
raise
VerifyError
(
int
(
code
),
int
(
depth
),
msg
.
strip
())
raise
VerifyError
(
int
(
code
),
int
(
depth
),
msg
.
strip
())
return
r
return
r
def
verify
(
self
,
sign
,
data
):
def
verify
(
self
,
sign
:
bytes
,
data
:
bytes
):
crypto
.
verify
(
self
.
ca
,
sign
,
data
,
'sha512'
)
pub_key
=
self
.
ca_crypto
.
public_key
()
pub_key
.
verify
(
def
sign
(
self
,
data
):
sign
,
return
crypto
.
sign
(
self
.
key
,
data
,
'sha512'
)
data
,
padding
.
PKCS1v15
(),
def
decrypt
(
self
,
data
):
hashes
.
SHA512
()
)
def
sign
(
self
,
data
:
bytes
)
->
bytes
:
return
self
.
key_crypto
.
sign
(
data
,
padding
.
PKCS1v15
(),
hashes
.
SHA512
()
)
def
decrypt
(
self
,
data
:
bytes
)
->
bytes
:
p
=
openssl
(
'rsautl'
,
'-decrypt'
,
'-inkey'
,
self
.
key_path
)
p
=
openssl
(
'rsautl'
,
'-decrypt'
,
'-inkey'
,
self
.
key_path
)
out
,
err
=
p
.
communicate
(
data
)
out
,
err
=
p
.
communicate
(
data
)
if
p
.
returncode
:
if
p
.
returncode
:
...
@@ -166,7 +193,7 @@ class Cert(object):
...
@@ -166,7 +193,7 @@ class Cert(object):
def
verifyVersion
(
self
,
version
):
def
verifyVersion
(
self
,
version
):
try
:
try
:
n
=
1
+
(
ord
(
version
[
0
])
>>
5
)
n
=
1
+
(
version
[
0
]
>>
5
)
self
.
verify
(
version
[
n
:],
version
[:
n
])
self
.
verify
(
version
[
n
:],
version
[:
n
])
except
(
IndexError
,
crypto
.
Error
):
except
(
IndexError
,
crypto
.
Error
):
raise
VerifyError
(
None
,
None
,
'invalid network version'
)
raise
VerifyError
(
None
,
None
,
'invalid network version'
)
...
@@ -175,7 +202,7 @@ class Cert(object):
...
@@ -175,7 +202,7 @@ class Cert(object):
PACKED_PROTOCOL
=
utils
.
packInteger
(
protocol
)
PACKED_PROTOCOL
=
utils
.
packInteger
(
protocol
)
class
Peer
(
object
)
:
class
Peer
:
"""
"""
UDP: A ─────────────────────────────────────────────> B
UDP: A ─────────────────────────────────────────────> B
...
@@ -206,9 +233,10 @@ class Peer(object):
...
@@ -206,9 +233,10 @@ class Peer(object):
_key
=
newHmacSecret
()
_key
=
newHmacSecret
()
serial
=
None
serial
=
None
stop_date
=
float
(
'inf'
)
stop_date
=
float
(
'inf'
)
version
=
''
version
=
b''
cert
:
crypto
.
X509
def
__init__
(
self
,
prefix
):
def
__init__
(
self
,
prefix
:
str
):
self
.
prefix
=
prefix
self
.
prefix
=
prefix
@
property
@
property
...
@@ -224,35 +252,35 @@ class Peer(object):
...
@@ -224,35 +252,35 @@ class Peer(object):
def
__lt__
(
self
,
other
):
def
__lt__
(
self
,
other
):
return
self
.
prefix
<
(
other
if
type
(
other
)
is
str
else
other
.
prefix
)
return
self
.
prefix
<
(
other
if
type
(
other
)
is
str
else
other
.
prefix
)
def
hello0
(
self
,
cert
)
:
def
hello0
(
self
,
cert
:
crypto
.
X509
)
->
bytes
:
if
self
.
_hello
<
time
.
time
():
if
self
.
_hello
<
time
.
time
():
try
:
try
:
# Always assume peer is not old, in case it has just upgraded,
# Always assume peer is not old, in case it has just upgraded,
# else we would be stuck with the old protocol.
# else we would be stuck with the old protocol.
msg
=
(
'
\
0
\
0
\
0
\
1
'
msg
=
(
b
'
\
0
\
0
\
0
\
1
'
+
PACKED_PROTOCOL
+
PACKED_PROTOCOL
+
fingerprint
(
self
.
cert
).
digest
())
+
fingerprint
(
self
.
cert
).
digest
())
except
AttributeError
:
except
AttributeError
:
msg
=
'
\
0
\
0
\
0
\
0
'
msg
=
b
'
\
0
\
0
\
0
\
0
'
return
msg
+
crypto
.
dump_certificate
(
crypto
.
FILETYPE_ASN1
,
cert
)
return
msg
+
crypto
.
dump_certificate
(
crypto
.
FILETYPE_ASN1
,
cert
)
def
hello0Sent
(
self
):
def
hello0Sent
(
self
):
self
.
_hello
=
time
.
time
()
+
60
self
.
_hello
=
time
.
time
()
+
60
def
hello
(
self
,
cert
,
protocol
)
:
def
hello
(
self
,
cert
:
Cert
,
protocol
:
int
)
->
bytes
:
key
=
self
.
_key
=
newHmacSecret
()
key
=
self
.
_key
=
newHmacSecret
()
h
=
encrypt
(
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
),
h
=
encrypt
(
crypto
.
dump_certificate
(
crypto
.
FILETYPE_PEM
,
self
.
cert
),
key
)
key
)
self
.
_i
=
self
.
_j
=
2
self
.
_i
=
self
.
_j
=
2
self
.
_last
=
0
self
.
_last
=
0
self
.
protocol
=
protocol
self
.
protocol
=
protocol
return
''
.
join
((
'
\
0
\
0
\
0
\
2
'
,
PACKED_PROTOCOL
if
protocol
else
''
,
return
b''
.
join
((
b'
\
0
\
0
\
0
\
2
'
,
PACKED_PROTOCOL
if
protocol
else
b
''
,
h
,
cert
.
sign
(
h
)))
h
,
cert
.
sign
(
h
)))
def
_hmac
(
self
,
msg
)
:
def
_hmac
(
self
,
msg
:
bytes
)
->
bytes
:
return
hmac
.
HMAC
(
self
.
_key
,
msg
,
hashlib
.
sha1
).
digest
()
return
hmac
.
HMAC
(
self
.
_key
,
msg
,
hashlib
.
sha1
).
digest
()
def
newSession
(
self
,
key
,
protocol
):
def
newSession
(
self
,
key
:
bytes
,
protocol
:
int
):
if
key
<=
self
.
_key
:
if
key
<=
self
.
_key
:
raise
NewSessionError
(
self
.
_key
,
key
)
raise
NewSessionError
(
self
.
_key
,
key
)
self
.
_key
=
key
self
.
_key
=
key
...
@@ -260,12 +288,13 @@ class Peer(object):
...
@@ -260,12 +288,13 @@ class Peer(object):
self
.
_last
=
None
self
.
_last
=
None
self
.
protocol
=
protocol
self
.
protocol
=
protocol
def
verify
(
self
,
sign
,
data
):
def
verify
(
self
,
sign
:
bytes
,
data
:
bytes
):
crypto
.
verify
(
self
.
cert
,
sign
,
data
,
'sha512'
)
crypto
.
verify
(
self
.
cert
,
sign
,
data
,
'sha512'
)
seqno_struct
=
struct
.
Struct
(
"!L"
)
seqno_struct
=
struct
.
Struct
(
"!L"
)
def
decode
(
self
,
msg
,
_unpack
=
seqno_struct
.
unpack
):
def
decode
(
self
,
msg
:
bytes
,
_unpack
=
seqno_struct
.
unpack
)
\
->
tuple
[
int
,
bytes
,
int
|
None
]
|
bytes
:
seqno
,
=
_unpack
(
msg
[:
4
])
seqno
,
=
_unpack
(
msg
[:
4
])
if
seqno
<=
2
:
if
seqno
<=
2
:
msg
=
msg
[
4
:]
msg
=
msg
[
4
:]
...
@@ -281,8 +310,10 @@ class Peer(object):
...
@@ -281,8 +310,10 @@ class Peer(object):
self
.
_i
=
seqno
self
.
_i
=
seqno
return
msg
[
4
:
i
]
return
msg
[
4
:
i
]
def
encode
(
self
,
msg
,
_pack
=
seqno_struct
.
pack
)
:
def
encode
(
self
,
msg
:
str
|
bytes
,
_pack
=
seqno_struct
.
pack
)
->
bytes
:
self
.
_j
+=
1
self
.
_j
+=
1
if
type
(
msg
)
is
str
:
msg
=
msg
.
encode
()
msg
=
_pack
(
self
.
_j
)
+
msg
msg
=
_pack
(
self
.
_j
)
+
msg
return
msg
+
self
.
_hmac
(
msg
)
return
msg
+
self
.
_hmac
(
msg
)
...
...
setup.py
View file @
e80ca878
...
@@ -7,21 +7,23 @@ from setuptools.command import sdist as _sdist, build_py as _build_py
...
@@ -7,21 +7,23 @@ from setuptools.command import sdist as _sdist, build_py as _build_py
from
distutils
import
log
from
distutils
import
log
version
=
{
"__file__"
:
"re6st/version.py"
}
version
=
{
"__file__"
:
"re6st/version.py"
}
execfile
(
version
[
"__file__"
],
version
)
with
open
(
version
[
"__file__"
])
as
f
:
code
=
compile
(
f
.
read
(),
version
[
"__file__"
],
'exec'
)
exec
(
code
,
version
)
def
copy_file
(
self
,
infile
,
outfile
,
*
args
,
**
kw
):
def
copy_file
(
self
,
infile
,
outfile
,
*
args
,
**
kw
):
if
infile
==
version
[
"__file__"
]:
if
infile
==
version
[
"__file__"
]:
if
not
self
.
dry_run
:
if
not
self
.
dry_run
:
log
.
info
(
"generating %s -> %s"
,
infile
,
outfile
)
log
.
info
(
"generating %s -> %s"
,
infile
,
outfile
)
with
open
(
outfile
,
"w
b
"
)
as
f
:
with
open
(
outfile
,
"w"
)
as
f
:
for
x
in
sorted
(
version
.
ite
rite
ms
()):
for
x
in
sorted
(
version
.
items
()):
if
not
x
[
0
].
startswith
(
"_"
):
if
not
x
[
0
].
startswith
(
"_"
):
f
.
write
(
"%s = %r
\
n
"
%
x
)
f
.
write
(
"%s = %r
\
n
"
%
x
)
return
outfile
,
1
return
outfile
,
1
elif
isinstance
(
self
,
build_py
)
and
\
elif
isinstance
(
self
,
build_py
)
and
\
os
.
stat
(
infile
).
st_mode
&
stat
.
S_IEXEC
:
os
.
stat
(
infile
).
st_mode
&
stat
.
S_IEXEC
:
if
os
.
path
.
isdir
(
infile
)
and
os
.
path
.
isdir
(
outfile
):
if
os
.
path
.
isdir
(
infile
)
and
os
.
path
.
isdir
(
outfile
):
return
(
outfile
,
0
)
return
outfile
,
0
# Adjust interpreter of OpenVPN hooks.
# Adjust interpreter of OpenVPN hooks.
with
open
(
infile
)
as
src
:
with
open
(
infile
)
as
src
:
first_line
=
src
.
readline
()
first_line
=
src
.
readline
()
...
@@ -33,7 +35,7 @@ def copy_file(self, infile, outfile, *args, **kw):
...
@@ -33,7 +35,7 @@ def copy_file(self, infile, outfile, *args, **kw):
patched
+=
src
.
read
()
patched
+=
src
.
read
()
dst
=
os
.
open
(
outfile
,
os
.
O_CREAT
|
os
.
O_WRONLY
|
os
.
O_TRUNC
)
dst
=
os
.
open
(
outfile
,
os
.
O_CREAT
|
os
.
O_WRONLY
|
os
.
O_TRUNC
)
try
:
try
:
os
.
write
(
dst
,
patched
)
os
.
write
(
dst
,
patched
.
encode
()
)
finally
:
finally
:
os
.
close
(
dst
)
os
.
close
(
dst
)
return
outfile
,
1
return
outfile
,
1
...
@@ -51,7 +53,8 @@ Environment :: Console
...
@@ -51,7 +53,8 @@ Environment :: Console
License :: OSI Approved :: GNU General Public License (GPL)
License :: OSI Approved :: GNU General Public License (GPL)
Natural Language :: English
Natural Language :: English
Operating System :: POSIX :: Linux
Operating System :: POSIX :: Linux
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.11
Topic :: Internet
Topic :: Internet
Topic :: System :: Networking
Topic :: System :: Networking
"""
"""
...
@@ -73,6 +76,7 @@ setup(
...
@@ -73,6 +76,7 @@ setup(
license
=
'GPL 2+'
,
license
=
'GPL 2+'
,
platforms
=
[
"any"
],
platforms
=
[
"any"
],
classifiers
=
classifiers
.
splitlines
(),
classifiers
=
classifiers
.
splitlines
(),
python_requires
=
'>=3.11'
,
long_description
=
".. contents::
\
n
\
n
"
+
open
(
'README.rst'
).
read
()
long_description
=
".. contents::
\
n
\
n
"
+
open
(
'README.rst'
).
read
()
+
"
\
n
"
+
open
(
'CHANGES.rst'
).
read
()
+
git_rev
,
+
"
\
n
"
+
open
(
'CHANGES.rst'
).
read
()
+
git_rev
,
packages
=
find_packages
(),
packages
=
find_packages
(),
...
@@ -95,7 +99,7 @@ setup(
...
@@ -95,7 +99,7 @@ setup(
extras_require
=
{
extras_require
=
{
'geoip'
:
[
'geoip2'
],
'geoip'
:
[
'geoip2'
],
'multicast'
:
[
'PyYAML'
],
'multicast'
:
[
'PyYAML'
],
'test'
:
[
'mock'
,
'
pathlib2'
,
'nemu'
,
'python-unshare'
,
'python-passfd
'
,
'multiping'
]
'test'
:
[
'mock'
,
'
nemu3'
,
'unshare
'
,
'multiping'
]
},
},
#dependency_links = [
#dependency_links = [
# "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7",
# "http://miniupnp.free.fr/files/download.php?file=miniupnpc-1.7.20120714.tar.gz#egg=miniupnpc-1.7",
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment