Commit a7a86341 authored by Julien Muchembled's avatar Julien Muchembled

New protocol between nodes with authentication

parent 32ebb80b
...@@ -297,21 +297,27 @@ if len(sys.argv) > 1: ...@@ -297,21 +297,27 @@ if len(sys.argv) > 1:
elif self.path == '/tunnel.html': elif self.path == '/tunnel.html':
other = 'route' other = 'route'
gv = registry.Popen(('python', '-c', r"""if 1: gv = registry.Popen(('python', '-c', r"""if 1:
import math import math, json
from re6st.registry import RegistryClient from re6st.registry import RegistryClient
g = eval(RegistryClient('http://localhost/').topology()) g = json.loads(RegistryClient(
'http://localhost/').topology())
r = set(g.pop('', ()))
a = set()
for v in g.itervalues():
a.update(v)
g.update(dict.fromkeys(a.difference(g), ()))
print 'digraph {' print 'digraph {'
a = 2 * math.pi / len(g) a = 2 * math.pi / len(g)
z = 4 z = 4
m2 = '%u/80' % (2 << 64) m2 = '%u/80' % (2 << 64)
title = lambda n: '2|80' if n == m2 else n title = lambda n: '2|80' if n == m2 else n
g = sorted((title(k), v) for k, v in g.iteritems()) g = sorted((title(k), k in r, v) for k, v in g.iteritems())
for i, (n, p) in enumerate(g): for i, (n, r, v) in enumerate(g):
print '"%s"[pos="%s,%s!"%s];' % (title(n), print '"%s"[pos="%s,%s!"%s];' % (title(n),
z * math.cos(a * i), z * math.sin(a * i), z * math.cos(a * i), z * math.sin(a * i),
', style=dashed' if p is None else '') '' if r else ', style=dashed')
for p in p or (): for v in v:
print '"%s" -> "%s";' % (n, title(p)) print '"%s" -> "%s";' % (n, title(v))
print '}' print '}'
"""), stdout=subprocess.PIPE, cwd="..").communicate()[0] """), stdout=subprocess.PIPE, cwd="..").communicate()[0]
if gv: if gv:
......
...@@ -12,6 +12,7 @@ class PeerDB(object): ...@@ -12,6 +12,7 @@ class PeerDB(object):
logging.info('Initialize cache ...') logging.info('Initialize cache ...')
self._db = sqlite3.connect(db_path, isolation_level=None) self._db = sqlite3.connect(db_path, isolation_level=None)
self._db.text_factory = str
q = self._db.execute q = self._db.execute
q("PRAGMA synchronous = OFF") q("PRAGMA synchronous = OFF")
q("PRAGMA journal_mode = MEMORY") q("PRAGMA journal_mode = MEMORY")
......
...@@ -18,10 +18,10 @@ Authenticated communication: ...@@ -18,10 +18,10 @@ Authenticated communication:
- the last one that was really used by the client (!hello) - the last one that was really used by the client (!hello)
- the one of the last handshake (hello) - the one of the last handshake (hello)
""" """
import base64, hmac, hashlib, httplib, inspect, logging import base64, hmac, hashlib, httplib, inspect, json, logging
import mailbox, os, random, select, smtplib, socket, sqlite3 import mailbox, os, random, select, smtplib, socket, sqlite3
import string, struct, sys, threading, time, weakref import string, sys, threading, time, weakref
from collections import deque from collections import defaultdict, deque
from datetime import datetime from datetime import datetime
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
from email.mime.text import MIMEText from email.mime.text import MIMEText
...@@ -92,6 +92,20 @@ class RegistryServer(object): ...@@ -92,6 +92,20 @@ class RegistryServer(object):
self.onTimeout() self.onTimeout()
def sendto(self, prefix, code):
self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT))
def recv(self, code):
try:
prefix, msg = self.sock.recv(1<<16).split('\0', 1)
int(prefix, 2)
except ValueError:
pass
else:
if msg and ord(msg[0]) == code:
return prefix, msg[1:]
return None, None
def select(self, r, w, t): def select(self, r, w, t):
if self.timeout: if self.timeout:
t.append((self.timeout, self.onTimeout)) t.append((self.timeout, self.onTimeout))
...@@ -198,8 +212,7 @@ class RegistryServer(object): ...@@ -198,8 +212,7 @@ class RegistryServer(object):
def hello(self, client_prefix): def hello(self, client_prefix):
with self.lock: with self.lock:
cert = self.getCert(client_prefix) cert = self.getCert(client_prefix)
key = hashlib.sha1(struct.pack('Q', key = utils.newHmacSecret()
random.getrandbits(64))).digest()
self.sessions.setdefault(client_prefix, [])[1:] = key, self.sessions.setdefault(client_prefix, [])[1:] = key,
key = x509.encrypt(cert, key) key = x509.encrypt(cert, key)
sign = self.cert.sign(key) sign = self.cert.sign(key)
...@@ -349,26 +362,22 @@ class RegistryServer(object): ...@@ -349,26 +362,22 @@ class RegistryServer(object):
# (in case 'peers' is empty). # (in case 'peers' is empty).
peer = self.prefix peer = self.prefix
with self.lock: with self.lock:
address = utils.ipFromBin(self.network + peer), tunnel.PORT self.sendto(peer, 1)
self.sock.sendto('\2', address)
s = self.sock, s = self.sock,
timeout = 3 timeout = 3
end = timeout + time.time() end = timeout + time.time()
# Loop because there may be answers from previous requests. # Loop because there may be answers from previous requests.
while select.select(s, (), (), timeout)[0]: while select.select(s, (), (), timeout)[0]:
msg = self.sock.recv(1<<16) prefix, msg = self.recv(1)
if msg[0] == '\1': if prefix == peer:
try:
msg = msg[1:msg.index('\n')]
except ValueError:
continue
if msg.split()[0] == peer:
break break
timeout = max(0, end - time.time()) timeout = max(0, end - time.time())
else: else:
logging.info("Timeout while querying [%s]:%u", *address) logging.info("Timeout while querying address for %s/%s",
int(peer, 2), len(peer))
return return
cert = self.getCert(cn) cert = self.getCert(cn)
msg = "%s %s" % (peer, msg)
logging.info("Sending bootstrap peer: %s", msg) logging.info("Sending bootstrap peer: %s", msg)
return x509.encrypt(cert, msg) return x509.encrypt(cert, msg)
...@@ -387,59 +396,46 @@ class RegistryServer(object): ...@@ -387,59 +396,46 @@ class RegistryServer(object):
while True: while True:
r, w, _ = select.select(s, s if peers else (), (), 3) r, w, _ = select.select(s, s if peers else (), (), 3)
if r: if r:
ver, address = self.sock.recvfrom(1<<16) prefix, ver = self.recv(4)
address = utils.binFromIp(address[0]) if prefix:
if (address.startswith(self.network) and peer_dict[prefix] = ver
len(ver) > 1 and ver[0] in '\3\4' # BBB
):
try:
peer_dict[max(filter(address[len(self.network):]
.startswith, peer_dict),
key=len)] = ver[1:]
except ValueError:
pass
if w: if w:
x = peers.pop() prefix = peers.pop()
peer_dict[x] = None peer_dict[prefix] = None
x = utils.ipFromBin(self.network + x) self.sendto(prefix, 4)
try:
self.sock.sendto('\3', (x, tunnel.PORT))
except socket.error:
pass
elif not r: elif not r:
break break
return repr(peer_dict) return json.dumps(peer_dict)
@rpc @rpc
def topology(self): def topology(self):
with self.lock: p = lambda p: '%s/%s' % (int(p, 2), len(p))
peers = deque(('%u/%u' % (int(self.prefix, 2), len(self.prefix)),)) peers = deque((p(self.prefix),))
cookie = hex(random.randint(0, 1<<32))[2:] graph = defaultdict(set)
graph = dict.fromkeys(peers)
s = self.sock, s = self.sock,
with self.lock:
while True: while True:
r, w, _ = select.select(s, s if peers else (), (), 3) r, w, _ = select.select(s, s if peers else (), (), 3)
if r: if r:
answer = self.sock.recv(1<<16) prefix, x = self.recv(5)
if answer[0] == '\xfe': if prefix and x:
answer = answer[1:].split('\n')[:-1] prefix = p(prefix)
if len(answer) >= 3 and answer[0] == cookie: x = x.split()
x = answer[3:]
assert answer[1] not in x, (answer, graph)
graph[answer[1]] = x[:int(answer[2])]
x = set(x).difference(graph)
peers += x
graph.update(dict.fromkeys(x))
if w:
x = utils.binFromSubnet(peers.popleft())
x = utils.ipFromBin(self.network + x)
try: try:
self.sock.sendto('\xff%s\n' % cookie, (x, tunnel.PORT)) n = int(x.pop(0))
except socket.error: except ValueError:
pass continue
if n <= len(x) and prefix not in x:
graph[prefix].update(x[:n])
peers += set(x).difference(graph)
for x in x[n:]:
graph[x].add(prefix)
graph[''].add(prefix)
if w:
self.sendto(utils.binFromSubnet(peers.popleft()), 5)
elif not r: elif not r:
break break
return repr(graph) return json.dumps(dict((k, list(v)) for k, v in graph.iteritems()))
class RegistryClient(object): class RegistryClient(object):
......
This diff is collapsed.
import argparse, errno, logging, os, select as _select, shlex, signal import argparse, errno, hashlib, logging, os, select as _select, shlex, signal
import socket, struct, subprocess, sys, textwrap, threading, time, traceback import socket, struct, subprocess, sys, textwrap, threading, time, traceback
HMAC_LEN = len(hashlib.sha1('').digest())
try: try:
subprocess.CalledProcessError(0, '', '') subprocess.CalledProcessError(0, '', '')
except TypeError: # BBB: Python < 2.7 except TypeError: # BBB: Python < 2.7
...@@ -223,3 +226,10 @@ def parse_address(address_list): ...@@ -223,3 +226,10 @@ def parse_address(address_list):
def binFromSubnet(subnet): def binFromSubnet(subnet):
p, l = subnet.split('/') p, l = subnet.split('/')
return bin(int(p))[2:].rjust(int(l), '0') return bin(int(p))[2:].rjust(int(l), '0')
def newHmacSecret():
from random import getrandbits as g
pack = struct.Struct(">QQI").pack
assert len(pack(0,0,0)) == HMAC_LEN
return lambda x=None: pack(g(64) if x is None else x, g(64), g(32))
newHmacSecret = newHmacSecret()
import calendar, logging, os, subprocess, threading, time # -*- coding: utf-8 -*-
import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time
from collections import deque
from datetime import datetime
from OpenSSL import crypto from OpenSSL import crypto
from . import utils from . import utils
def newHmacSecret():
x = datetime.utcnow()
return utils.newHmacSecret(int(time.mktime(x.timetuple())) * 1000000
+ x.microsecond)
def networkFromCa(ca): def networkFromCa(ca):
return bin(ca.get_serial_number())[3:] return bin(ca.get_serial_number())[3:]
...@@ -65,9 +73,14 @@ def maybe_renew(path, cert, info, renew): ...@@ -65,9 +73,14 @@ def maybe_renew(path, cert, info, renew):
info, exc_info=exc_info) info, exc_info=exc_info)
return cert, time.time() + 86400 return cert, time.time() + 86400
class VerifyError(Exception): class VerifyError(Exception):
pass pass
class NewSessionError(Exception):
pass
class Cert(object): class Cert(object):
def __init__(self, ca, key, cert=None): def __init__(self, ca, key, cert=None):
...@@ -105,11 +118,13 @@ class Cert(object): ...@@ -105,11 +118,13 @@ class Cert(object):
"CA Certificate", registry.getCa) "CA Certificate", registry.getCa)
return min(next_renew, ca_renew) return min(next_renew, ca_renew)
def loadVerify(self, cert, strict=False): def loadVerify(self, cert, strict=False, type=crypto.FILETYPE_PEM):
try: try:
r = crypto.load_certificate(crypto.FILETYPE_PEM, cert) r = crypto.load_certificate(type, cert)
except crypto.Error: except crypto.Error:
raise VerifyError(None, None, 'unable to load certificate') raise VerifyError(None, None, 'unable to load certificate')
if type != crypto.FILETYPE_PEM:
cert = crypto.dump_certificate(crypto.FILETYPE_PEM, r)
p = openssl('verify', '-CAfile', self.ca_path) p = openssl('verify', '-CAfile', self.ca_path)
out, err = p.communicate(cert) out, err = p.communicate(cert)
if p.returncode or strict: if p.returncode or strict:
...@@ -132,3 +147,101 @@ class Cert(object): ...@@ -132,3 +147,101 @@ class Cert(object):
if p.returncode: if p.returncode:
raise subprocess.CalledProcessError(p.returncode, 'openssl', err) raise subprocess.CalledProcessError(p.returncode, 'openssl', err)
return out return out
class Peer(object):
"""
UDP: A ─────────────────────────────────────────────> B
hello0: 0, A
1, fingerprint(B), A
hello: 2, X = E(B)(secret), S(A)(X)
!hello: #, ver, type, value, HMAC(secret)(payload)
└──── payload ────┘
new secret > old secret
(concat timestamp with random bits)
Reject messages with # smaller or equal than previously processed.
Yes, we do UDP on purpose. The only drawbacks are:
- The limited size of packets, but they are big enough for a network
using 4096-bits RSA keys.
- hello0 packets (0 & 1) are subject to DoS, because verifying a
certificate uses much CPU. A solution would be to use TCP until the
secret is exchanged and continue with UDP.
"""
_hello = _last = 0
_key = newHmacSecret()
def __init__(self, prefix):
assert len(prefix) == 16 or prefix == ('0' * 14 + '1' + '0' * 65), prefix
self.prefix = prefix
@property
def connected(self):
return self._last is None or time.time() < self._last + 60
def __ne__(self, other):
raise AssertionError
__eq__ = __ge__ = __le__ = __ne__
def __gt__(self, other):
return self.prefix > (other if type(other) is str else other.prefix)
def __lt__(self, other):
return self.prefix < (other if type(other) is str else other.prefix)
def hello0(self, cert):
if self._hello < time.time():
try:
msg = '\0\0\0\1' + fingerprint(self.cert).digest()
except AttributeError:
msg = '\0\0\0\0'
return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
def hello0Sent(self):
self._hello = time.time() + 60
def hello(self, cert):
key = self._key = newHmacSecret()
h = encrypt(crypto.dump_certificate(crypto.FILETYPE_PEM, self.cert),
key)
self._i = self._j = 2
self._last = 0
return '\0\0\0\2' + h + cert.sign(h)
def _hmac(self, msg):
return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
def newSession(self, key):
if key <= self._key:
raise NewSessionError(self._key, key)
self._key = key
self._i = self._j = 2
self._last = None
def verify(self, sign, data):
crypto.verify(self.cert, sign, data, 'sha1')
seqno_struct = struct.Struct("!L")
def decode(self, msg, _unpack=seqno_struct.unpack):
seqno, = _unpack(msg[:4])
if seqno <= 2:
return seqno, msg[4:]
i = -utils.HMAC_LEN
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None
self._i = seqno
return msg[4:i]
def encode(self, msg, _pack=seqno_struct.pack):
self._j += 1
msg = _pack(self._j) + msg
return msg + self._hmac(msg)
del seqno_struct
def sent(self):
if not self._last:
self._last = time.time()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment