Commit 4f327e1b authored by Julien Muchembled's avatar Julien Muchembled

Add protocol to handshake between nodes

There is a need to be able to extend the protocol without breaking
compatibility with old nodes. This is done by sending version.protocol
during inter-node handshake, in seqno 1 and seqno 2, so that a node
knows what version the peers speak and use appropriate format.

This is implemented with partial backward compatibility: handshake with
an old node succeeds when the new node does not have to send seqno 1.
parent ee93c63e
...@@ -20,7 +20,7 @@ Authenticated communication: ...@@ -20,7 +20,7 @@ Authenticated communication:
""" """
import base64, hmac, hashlib, httplib, inspect, json, logging import base64, hmac, hashlib, httplib, inspect, json, logging
import mailbox, os, platform, random, select, smtplib, socket, sqlite3 import mailbox, os, platform, random, select, smtplib, socket, sqlite3
import string, struct, sys, threading, time, weakref, zlib import string, sys, threading, time, weakref, zlib
from collections import defaultdict, deque from collections import defaultdict, deque
from datetime import datetime from datetime import datetime
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler
...@@ -132,8 +132,7 @@ class RegistryServer(object): ...@@ -132,8 +132,7 @@ class RegistryServer(object):
kw[x] = getattr(self.config, x) kw[x] = getattr(self.config, x)
config = json.dumps(kw, sort_keys=True) config = json.dumps(kw, sort_keys=True)
if config != self.getConfig('last_config', None): if config != self.getConfig('last_config', None):
self.version = self.encodeVersion( self.increaseVersion()
1 + self.decodeVersion(self.version))
# BBB: Use buffer because of http://bugs.python.org/issue13676 # BBB: Use buffer because of http://bugs.python.org/issue13676
# on Python 2.6 # on Python 2.6
self.setConfig('version', buffer(self.version)) self.setConfig('version', buffer(self.version))
...@@ -144,20 +143,9 @@ class RegistryServer(object): ...@@ -144,20 +143,9 @@ class RegistryServer(object):
kw['version'] = self.version.encode('base64') kw['version'] = self.version.encode('base64')
self.network_config = kw self.network_config = kw
# The 3 first bits code the number of bytes. def increaseVersion(self):
def encodeVersion(self, version): x = utils.packInteger(1 + utils.unpackInteger(self.version)[0])
for n in xrange(8): self.version = x + self.cert.sign(x)
x = 32 << 8 * n
if version < x:
x = struct.pack("!Q", version + n * x)[7-n:]
return x + self.cert.sign(x)
version -= x
def decodeVersion(self, version):
n = ord(version[0]) >> 5
version, = struct.unpack("!Q", '\0' * (7 - n) + version[:n+1])
return sum((32 << 8 * n for n in xrange(n)),
version - (n * 32 << 8 * n))
def sendto(self, prefix, code): def sendto(self, prefix, code):
self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT)) self.sock.sendto("%s\0%c" % (prefix, code), ('::1', tunnel.PORT))
...@@ -584,8 +572,7 @@ class RegistryServer(object): ...@@ -584,8 +572,7 @@ class RegistryServer(object):
# Initialization of HMAC on the network # Initialization of HMAC on the network
self.newHMAC(1) self.newHMAC(1)
self.newHMAC(2, '') self.newHMAC(2, '')
self.version = self.encodeVersion( self.increaseVersion()
1 + self.decodeVersion(self.version))
self.setConfig('version', buffer(self.version)) self.setConfig('version', buffer(self.version))
self.network_config['version'] = self.version.encode('base64') self.network_config['version'] = self.version.encode('base64')
self.sendto(self.prefix, 0) self.sendto(self.prefix, 0)
......
...@@ -374,7 +374,7 @@ class BaseTunnelManager(object): ...@@ -374,7 +374,7 @@ class BaseTunnelManager(object):
return return
try: try:
sender = utils.binFromIp(address[0]) sender = utils.binFromIp(address[0])
except socket.error, e: except socket.error:
return # inet_pton does not parse '<ipv6>%<iface>' return # inet_pton does not parse '<ipv6>%<iface>'
if len(msg) <= 4 or not sender.startswith(self._network): if len(msg) <= 4 or not sender.startswith(self._network):
return return
...@@ -382,7 +382,8 @@ class BaseTunnelManager(object): ...@@ -382,7 +382,8 @@ class BaseTunnelManager(object):
peer = self._getPeer(prefix) peer = self._getPeer(prefix)
msg = peer.decode(msg) msg = peer.decode(msg)
if type(msg) is tuple: if type(msg) is tuple:
seqno, msg = msg real_seqno, msg = msg
def handleHello(peer, seqno, msg):
if seqno == 2: if seqno == 2:
i = len(msg) // 2 i = len(msg) // 2
h = msg[:i] h = msg[:i]
...@@ -409,6 +410,8 @@ class BaseTunnelManager(object): ...@@ -409,6 +410,8 @@ class BaseTunnelManager(object):
if serial in self.cache.crl: if serial in self.cache.crl:
raise ValueError("revoked") raise ValueError("revoked")
except (x509.VerifyError, ValueError), e: except (x509.VerifyError, ValueError), e:
if real_seqno and peer.hello_protocol:
return True
logging.debug('ignored invalid certificate from %r (%s)', logging.debug('ignored invalid certificate from %r (%s)',
address, e.args[-1]) address, e.args[-1])
return return
...@@ -430,6 +433,12 @@ class BaseTunnelManager(object): ...@@ -430,6 +433,12 @@ class BaseTunnelManager(object):
msg = peer.hello0(self.cert.cert) msg = peer.hello0(self.cert.cert)
if msg and self._sendto(to, msg): if msg and self._sendto(to, msg):
peer.hello0Sent() peer.hello0Sent()
if handleHello(peer, real_seqno, msg):
# It is possible to reconstruct the original message because
# the serialization of the protocol version is always unique.
msg = utils.packInteger(peer.hello_protocol) + msg
peer.hello_protocol = 0
handleHello(peer, real_seqno, msg)
elif msg: elif msg:
# We got a valid and non-empty message. Always reply # We got a valid and non-empty message. Always reply
# something so that the sender knows we're still connected. # something so that the sender knows we're still connected.
......
...@@ -254,6 +254,31 @@ def newHmacSecret(): ...@@ -254,6 +254,31 @@ def newHmacSecret():
return lambda x=None: pack(g(64) if x is None else x, g(64), g(32)) return lambda x=None: pack(g(64) if x is None else x, g(64), g(32))
newHmacSecret = newHmacSecret() newHmacSecret = newHmacSecret()
### Integer serialization
# - supports values from 0 to 0x202020202020201f
# - preserves ordering
# - there's always a unique way to encode a value
# - the 3 first bits code the number of bytes
def packInteger(i):
for n in xrange(8):
x = 32 << 8 * n
if i < x:
return struct.pack("!Q", i + n * x)[7-n:]
i -= x
raise OverflowError
def unpackInteger(x):
n = ord(x[0]) >> 5
try:
i, = struct.unpack("!Q", '\0' * (7 - n) + x[:n+1])
except struct.error:
return
return sum((32 << 8 * i for i in xrange(n)),
i - (n * 32 << 8 * n)), n + 1
###
def sqliteCreateTable(db, name, *columns): def sqliteCreateTable(db, name, *columns):
sql = "CREATE TABLE %s (%s)" % (name, ','.join('\n ' + x for x in columns)) sql = "CREATE TABLE %s (%s)" % (name, ','.join('\n ' + x for x in columns))
for x, in db.execute( for x, in db.execute(
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time import calendar, hashlib, hmac, logging, os, struct, subprocess, threading, time
from OpenSSL import crypto from OpenSSL import crypto
from . import utils from . import utils
from .version import protocol
def newHmacSecret(): def newHmacSecret():
return utils.newHmacSecret(int(time.time() * 1000000)) return utils.newHmacSecret(int(time.time() * 1000000))
...@@ -171,6 +172,9 @@ class Cert(object): ...@@ -171,6 +172,9 @@ class Cert(object):
raise VerifyError(None, None, 'invalid network version') raise VerifyError(None, None, 'invalid network version')
PACKED_PROTOCOL = utils.packInteger(protocol)
class Peer(object): class Peer(object):
""" """
UDP: A ─────────────────────────────────────────────> B UDP: A ─────────────────────────────────────────────> B
...@@ -225,7 +229,11 @@ class Peer(object): ...@@ -225,7 +229,11 @@ class Peer(object):
def hello0(self, cert): def hello0(self, cert):
if self._hello < time.time(): if self._hello < time.time():
try: try:
msg = '\0\0\0\1' + fingerprint(self.cert).digest() # Always assume peer is not old, in case it has just upgraded,
# else we would be stuck with the old protocol.
msg = ('\0\0\0\1'
+ PACKED_PROTOCOL
+ fingerprint(self.cert).digest())
except AttributeError: except AttributeError:
msg = '\0\0\0\0' msg = '\0\0\0\0'
return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert) return msg + crypto.dump_certificate(crypto.FILETYPE_ASN1, cert)
...@@ -239,7 +247,9 @@ class Peer(object): ...@@ -239,7 +247,9 @@ class Peer(object):
key) key)
self._i = self._j = 2 self._i = self._j = 2
self._last = 0 self._last = 0
return '\0\0\0\2' + h + cert.sign(h) self.protocol = self.hello_protocol
return ''.join(('\0\0\0\2', PACKED_PROTOCOL if self.protocol else '',
h, cert.sign(h)))
def _hmac(self, msg): def _hmac(self, msg):
return hmac.HMAC(self._key, msg, hashlib.sha1).digest() return hmac.HMAC(self._key, msg, hashlib.sha1).digest()
...@@ -250,6 +260,7 @@ class Peer(object): ...@@ -250,6 +260,7 @@ class Peer(object):
self._key = key self._key = key
self._i = self._j = 2 self._i = self._j = 2
self._last = None self._last = None
self.protocol = self.hello_protocol
def verify(self, sign, data): def verify(self, sign, data):
crypto.verify(self.cert, sign, data, 'sha512') crypto.verify(self.cert, sign, data, 'sha512')
...@@ -259,7 +270,11 @@ class Peer(object): ...@@ -259,7 +270,11 @@ class Peer(object):
def decode(self, msg, _unpack=seqno_struct.unpack): def decode(self, msg, _unpack=seqno_struct.unpack):
seqno, = _unpack(msg[:4]) seqno, = _unpack(msg[:4])
if seqno <= 2: if seqno <= 2:
return seqno, msg[4:] msg = msg[4:]
if seqno:
self.hello_protocol, n = utils.unpackInteger(msg) or (0, 0)
msg = msg[n:]
return seqno, msg
i = -utils.HMAC_LEN i = -utils.HMAC_LEN
if self._hmac(msg[:i]) == msg[i:] and self._i < seqno: if self._hmac(msg[:i]) == msg[i:] and self._i < seqno:
self._last = None self._last = None
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment