Commit 7f1d0c39 authored by Kirill Smelkov's avatar Kirill Smelkov

go/neo/neonet: MessagePack support for link layer (draft)

This patch adds support for serializing packet frames with M encoding on
the wire. To do so it follows rules defined in

    	nexedi/neoppod@9d0bf97a
    	( nexedi/neoppod!11 )

Server handshake is reworked to autodetect client's preferred encoding.
Client always prefers 'N' for now.
parent a1ef272f
...@@ -100,9 +100,11 @@ import ( ...@@ -100,9 +100,11 @@ import (
"lab.nexedi.com/kirr/neo/go/internal/packed" "lab.nexedi.com/kirr/neo/go/internal/packed"
"lab.nexedi.com/kirr/neo/go/internal/xio" "lab.nexedi.com/kirr/neo/go/internal/xio"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
"lab.nexedi.com/kirr/neo/go/neo/proto" "lab.nexedi.com/kirr/neo/go/neo/proto"
"github.com/someonegg/gocontainer/rbuf" "github.com/someonegg/gocontainer/rbuf"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xbytes" "lab.nexedi.com/kirr/go123/xbytes"
) )
...@@ -123,7 +125,7 @@ import ( ...@@ -123,7 +125,7 @@ import (
// It is safe to use NodeLink from multiple goroutines simultaneously. // It is safe to use NodeLink from multiple goroutines simultaneously.
type NodeLink struct { type NodeLink struct {
peerLink net.Conn // raw conn to peer peerLink net.Conn // raw conn to peer
enc proto.Encoding // protocol encoding in use ('N') enc proto.Encoding // protocol encoding in use ('N' or 'M')
connMu sync.Mutex connMu sync.Mutex
connTab map[uint32]*Conn // connId -> Conn associated with connId connTab map[uint32]*Conn // connId -> Conn associated with connId
...@@ -153,6 +155,8 @@ type NodeLink struct { ...@@ -153,6 +155,8 @@ type NodeLink struct {
closed atomic32 // whether Close was called closed atomic32 // whether Close was called
rxbufN rbuf.RingBuf // buffer for reading from peerLink (N encoding) rxbufN rbuf.RingBuf // buffer for reading from peerLink (N encoding)
rxbufM *msgp.Reader // ----//---- (M encoding)
rxbufMlimit *io.LimitedReader // limiter inserted inbetween rxbufM and peerLink
// scheduling optimization: whenever serveRecv sends to Conn.rxq // scheduling optimization: whenever serveRecv sends to Conn.rxq
// receiving side must ack here to receive G handoff. // receiving side must ack here to receive G handoff.
...@@ -261,7 +265,9 @@ const ( ...@@ -261,7 +265,9 @@ const (
// //
// Though it is possible to wrap just-established raw connection into NodeLink, // Though it is possible to wrap just-established raw connection into NodeLink,
// users should always use Handshake which performs protocol handshaking first. // users should always use Handshake which performs protocol handshaking first.
func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole) *NodeLink { //
// rxbuf if != nil indicates what was already read-buffered from conn.
func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *xbufReader) *NodeLink {
var nextConnId uint32 var nextConnId uint32
switch role &^ linkFlagsMask { switch role &^ linkFlagsMask {
case _LinkServer: case _LinkServer:
...@@ -283,6 +289,27 @@ func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole) *NodeLink { ...@@ -283,6 +289,27 @@ func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole) *NodeLink {
// axdown: make(chan struct{}), // axdown: make(chan struct{}),
down: make(chan struct{}), down: make(chan struct{}),
} }
if rxbuf == nil {
rxbuf = newXBufReader(conn, 0)
}
switch enc {
case 'N':
// rxbufN <- rxbufM (what was preread)
b, err := rxbuf.Next(rxbuf.Buffered())
if err != nil {
panic(err) // must not fail
}
nl.rxbufN.Write(b)
case 'M':
nl.rxbufM = &msgp.Reader{R: rxbuf.Reader}
nl.rxbufMlimit = &rxbuf.Limit
nl.rxbufMlimit.N = 0 // reads will fail unless .N is explicitly reset
default:
panic("bug")
}
if role&linkNoRecvSend == 0 { if role&linkNoRecvSend == 0 {
nl.serveWg.Add(2) nl.serveWg.Add(2)
go nl.serveRecv() go nl.serveRecv()
...@@ -1190,6 +1217,7 @@ var ErrPktTooBig = errors.New("packet too big") ...@@ -1190,6 +1217,7 @@ var ErrPktTooBig = errors.New("packet too big")
func (nl *NodeLink) recvPkt() (pkt *pktBuf, err error) { func (nl *NodeLink) recvPkt() (pkt *pktBuf, err error) {
switch nl.enc { switch nl.enc {
case 'N': pkt, err = nl.recvPktN() case 'N': pkt, err = nl.recvPktN()
case 'M': pkt, err = nl.recvPktM()
default: panic("bug") default: panic("bug")
} }
...@@ -1272,6 +1300,25 @@ func (nl *NodeLink) recvPktN() (*pktBuf, error) { ...@@ -1272,6 +1300,25 @@ func (nl *NodeLink) recvPktN() (*pktBuf, error) {
return pkt, nil return pkt, nil
} }
func (nl *NodeLink) recvPktM() (*pktBuf, error) {
pkt := pktAlloc(4096)
mraw := msgp.Raw(pkt.data)
// limit size of one packet to proto.PktMaxSize
// we don't care if it will be slightly more with what is already buffered
nl.rxbufMlimit.N = proto.PktMaxSize
err := mraw.DecodeMsg(nl.rxbufM)
if err != nil {
if nl.rxbufMlimit.N <= 0 {
err = ErrPktTooBig
}
return nil, err
}
pkt.data = []byte(mraw)
return pkt, nil
}
// ---- for convenience: Conn -> NodeLink & local/remote link addresses ---- // ---- for convenience: Conn -> NodeLink & local/remote link addresses ----
...@@ -1343,6 +1390,7 @@ func (c *Conn) err(op string, e error) error { ...@@ -1343,6 +1390,7 @@ func (c *Conn) err(op string, e error) error {
func pktEncode(e proto.Encoding, connId uint32, msg proto.Msg) *pktBuf { func pktEncode(e proto.Encoding, connId uint32, msg proto.Msg) *pktBuf {
switch e { switch e {
case 'N': return pktEncodeN(connId, msg) case 'N': return pktEncodeN(connId, msg)
case 'M': return pktEncodeM(connId, msg)
default: panic("bug") default: panic("bug")
} }
} }
...@@ -1351,6 +1399,7 @@ func pktEncode(e proto.Encoding, connId uint32, msg proto.Msg) *pktBuf { ...@@ -1351,6 +1399,7 @@ func pktEncode(e proto.Encoding, connId uint32, msg proto.Msg) *pktBuf {
func pktDecodeHead(e proto.Encoding, pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) { func pktDecodeHead(e proto.Encoding, pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
switch e { switch e {
case 'N': connID, msgCode, payload, err = pktDecodeHeadN(pkt) case 'N': connID, msgCode, payload, err = pktDecodeHeadN(pkt)
case 'M': connID, msgCode, payload, err = pktDecodeHeadM(pkt)
default: panic("bug") default: panic("bug")
} }
...@@ -1391,6 +1440,52 @@ func pktDecodeHeadN(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, ...@@ -1391,6 +1440,52 @@ func pktDecodeHeadN(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte,
return return
} }
func pktEncodeM(connId uint32, msg proto.Msg) *pktBuf {
const enc = proto.Encoding('M')
// [3](connID, msgCode, argv)
msgCode := proto.MsgCode(msg)
hroom := msgpack.ArrayHeadSize(3) +
msgpack.Uint32Size(connId) +
msgpack.Uint16Size(msgCode)
l := enc.MsgEncodedLen(msg)
buf := pktAlloc(hroom + l)
b := buf.data
i := 0
i += msgpack.PutArrayHead (b[i:], 3)
i += msgpack.PutUint32 (b[i:], connId)
i += msgpack.PutUint16 (b[i:], msgCode)
if i != hroom {
panic("bug")
}
enc.MsgEncode(msg, b[hroom:])
return buf
}
func pktDecodeHeadM(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
b := pkt.data
sz, b, err := msgp.ReadArrayHeaderBytes(b)
if err != nil {
return 0, 0, nil, err
}
if sz != 3 {
return 0, 0, nil, fmt.Errorf("expected [3]tuple, got [%d]tuple", sz)
}
connID, b, err = msgp.ReadUint32Bytes(b)
if err != nil {
return 0, 0, nil, fmt.Errorf("connID: %s", err)
}
msgCode, b, err = msgp.ReadUint16Bytes(b)
if err != nil {
return 0, 0, nil, fmt.Errorf("msgCode: %s", err)
}
return connID, msgCode, b, nil
}
// Recv receives message from the connection. // Recv receives message from the connection.
func (c *Conn) Recv() (proto.Msg, error) { func (c *Conn) Recv() (proto.Msg, error) {
......
...@@ -39,6 +39,8 @@ import ( ...@@ -39,6 +39,8 @@ import (
"lab.nexedi.com/kirr/neo/go/neo/proto" "lab.nexedi.com/kirr/neo/go/neo/proto"
"lab.nexedi.com/kirr/neo/go/zodb" "lab.nexedi.com/kirr/neo/go/zodb"
"github.com/tinylib/msgp/msgp"
"github.com/kylelemons/godebug/pretty" "github.com/kylelemons/godebug/pretty"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
...@@ -52,7 +54,7 @@ type T struct { ...@@ -52,7 +54,7 @@ type T struct {
// Verify tests f for all possible environments. // Verify tests f for all possible environments.
func Verify(t *testing.T, f func(*T)) { func Verify(t *testing.T, f func(*T)) {
// for each encoding // for each encoding
for _, enc := range []proto.Encoding{'N'} { for _, enc := range []proto.Encoding{'N', 'M'} {
t.Run(fmt.Sprintf("enc=%c", enc), func(t *testing.T) { t.Run(fmt.Sprintf("enc=%c", enc), func(t *testing.T) {
f(&T{t, enc}) f(&T{t, enc})
}) })
...@@ -63,6 +65,7 @@ func Verify(t *testing.T, f func(*T)) { ...@@ -63,6 +65,7 @@ func Verify(t *testing.T, f func(*T)) {
func (t *T) bin(data string) []byte { func (t *T) bin(data string) []byte {
switch t.enc { switch t.enc {
case 'N': return []byte(data) case 'N': return []byte(data)
case 'M': return msgp.AppendBytes(nil, []byte(data))
default: panic("bug") default: panic("bug")
} }
} }
...@@ -139,6 +142,15 @@ func _mkpkt(enc proto.Encoding, connid uint32, msgcode uint16, payload []byte) * ...@@ -139,6 +142,15 @@ func _mkpkt(enc proto.Encoding, connid uint32, msgcode uint16, payload []byte) *
copy(pkt.PayloadN(), payload) copy(pkt.PayloadN(), payload)
return pkt return pkt
case 'M':
var b []byte
b = msgp.AppendArrayHeader (b, 3)
b = msgp.AppendUint32 (b, connid)
b = msgp.AppendUint16 (b, msgcode)
// NOTE payload is assumed to be valid msgpack-encoded object.
b = append (b, payload...)
return &pktBuf{b}
default: default:
panic("bug") panic("bug")
} }
...@@ -195,8 +207,8 @@ func tdelay() { ...@@ -195,8 +207,8 @@ func tdelay() {
// create NodeLinks connected via net.Pipe; messages are encoded via t.enc. // create NodeLinks connected via net.Pipe; messages are encoded via t.enc.
func (t *T) _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) { func (t *T) _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) {
node1, node2 := net.Pipe() node1, node2 := net.Pipe()
nl1 = newNodeLink(node1, t.enc, _LinkClient|flags1) nl1 = newNodeLink(node1, t.enc, _LinkClient|flags1, nil)
nl2 = newNodeLink(node2, t.enc, _LinkServer|flags2) nl2 = newNodeLink(node2, t.enc, _LinkServer|flags2, nil)
return nl1, nl2 return nl1, nl2
} }
......
// Copyright (C) 2016-2018 Nexedi SA and Contributors. // Copyright (C) 2016-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com> // Kirill Smelkov <kirr@nexedi.com>
// //
// This program is free software: you can Use, Study, Modify and Redistribute // This program is free software: you can Use, Study, Modify and Redistribute
...@@ -20,7 +20,13 @@ ...@@ -20,7 +20,13 @@
package neonet package neonet
// syntax sugar for atomic load/store to raise signal/noise in logic // syntax sugar for atomic load/store to raise signal/noise in logic
import "sync/atomic" import (
"io"
"sync/atomic"
"github.com/philhofer/fwd"
)
type atomic32 struct { type atomic32 struct {
v int32 // struct member so `var a atomic32; if a == 0 ...` does not work v int32 // struct member so `var a atomic32; if a == 0 ...` does not work
...@@ -37,3 +43,16 @@ func (a *atomic32) Set(v int32) { ...@@ -37,3 +43,16 @@ func (a *atomic32) Set(v int32) {
func (a *atomic32) Add(δ int32) int32 { func (a *atomic32) Add(δ int32) int32 {
return atomic.AddInt32(&a.v, δ) return atomic.AddInt32(&a.v, δ)
} }
// xbufReader provides fwd.Reader with io.LimitedReader inserted underneath it.
type xbufReader struct {
*fwd.Reader
Limit io.LimitedReader // .Reader reads through .limit
}
func newXBufReader(r io.Reader, n int64) *xbufReader {
rxbuf := &xbufReader{Limit: io.LimitedReader{R: r, N: n}}
rxbuf.Reader = fwd.NewReader(&rxbuf.Limit)
return rxbuf
}
...@@ -21,12 +21,14 @@ package neonet ...@@ -21,12 +21,14 @@ package neonet
// link establishment // link establishment
import ( import (
"bytes"
"context" "context"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"net" "net"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xerr" "lab.nexedi.com/kirr/go123/xerr"
"lab.nexedi.com/kirr/go123/xnet" "lab.nexedi.com/kirr/go123/xnet"
"lab.nexedi.com/kirr/neo/go/internal/xcontext" "lab.nexedi.com/kirr/neo/go/internal/xcontext"
...@@ -41,7 +43,7 @@ import ( ...@@ -41,7 +43,7 @@ import (
// do not have such uses. // do not have such uses.
func _HandshakeClient(ctx context.Context, conn net.Conn) (*NodeLink, error) { func _HandshakeClient(ctx context.Context, conn net.Conn) (*NodeLink, error) {
return handshakeClient(ctx, conn, proto.Version) return handshakeClient(ctx, conn, proto.Version, proto.Encoding('N'))
} }
func _HandshakeServer(ctx context.Context, conn net.Conn) (*NodeLink, error) { func _HandshakeServer(ctx context.Context, conn net.Conn) (*NodeLink, error) {
...@@ -72,52 +74,54 @@ func (e *_HandshakeError) Unwrap() error { return e.Err } ...@@ -72,52 +74,54 @@ func (e *_HandshakeError) Unwrap() error { return e.Err }
// handshakeClient implements client-side NEO protocol handshake just after raw // handshakeClient implements client-side NEO protocol handshake just after raw
// connection between 2 nodes was established. // connection between 2 nodes was established.
// //
// Client indicates its version to server. // Client indicates its version and preferred encoding, but accepts any
// encoding chosen to use by server.
// //
// On success raw connection is returned wrapped into NodeLink. // On success raw connection is returned wrapped into NodeLink.
// On error raw connection is closed. // On error raw connection is closed.
func handshakeClient(ctx context.Context, conn net.Conn, version uint32) (*NodeLink, error) { func handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPrefer proto.Encoding) (*NodeLink, error) {
err := _handshakeClient(ctx, conn, version) enc, rxbuf, err := _handshakeClient(ctx, conn, version, encPrefer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
enc := proto.Encoding('N') return newNodeLink(conn, enc, _LinkClient, rxbuf), nil
return newNodeLink(conn, enc, _LinkClient), nil
} }
// handshakeServer implements server-side NEO protocol handshake just after raw // handshakeServer implements server-side NEO protocol handshake just after raw
// connection between 2 nodes was established. // connection between 2 nodes was established.
// //
// Server verifies that its version matches Client. // Server verifies that its version matches Client and accepts client preferred encoding.
// //
// On success raw connection is returned wrapped into NodeLink. // On success raw connection is returned wrapped into NodeLink.
// On error raw connection is closed. // On error raw connection is closed.
func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (*NodeLink, error) { func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (*NodeLink, error) {
err := _handshakeServer(ctx, conn, version) enc, rxbuf, err := _handshakeServer(ctx, conn, version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
enc := proto.Encoding('N') return newNodeLink(conn, enc, _LinkServer, rxbuf), nil
return newNodeLink(conn, enc, _LinkServer), nil
} }
func _handshakeClient(ctx context.Context, conn net.Conn, version uint32) (err error) { func _handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPrefer proto.Encoding) (enc proto.Encoding, rxbuf *xbufReader, err error) {
defer func() { defer func() {
if err != nil { if err != nil {
err = &_HandshakeError{_LinkClient, conn.LocalAddr(), conn.RemoteAddr(), err} err = &_HandshakeError{_LinkClient, conn.LocalAddr(), conn.RemoteAddr(), err}
} }
}() }()
rxbuf = newXBufReader(conn, /*any non-small limit*/1024)
var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error { err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
// tx client hello // tx client hello
err := txHello("tx hello", conn, version) err := txHello("tx hello", conn, version, encPrefer)
if err != nil { if err != nil {
return err return err
} }
// rx server hello reply // rx server hello reply
var peerVer uint32 var peerVer uint32
peerVer, err = rxHello("rx hello reply", conn) peerEnc, peerVer, err = rxHello("rx hello reply", rxbuf)
if err != nil { if err != nil {
return err return err
} }
...@@ -130,24 +134,29 @@ func _handshakeClient(ctx context.Context, conn net.Conn, version uint32) (err e ...@@ -130,24 +134,29 @@ func _handshakeClient(ctx context.Context, conn net.Conn, version uint32) (err e
return nil return nil
}) })
if err != nil { if err != nil {
return err return 0, nil, err
} }
return nil // use peer encoding (server should return the same, but we are ok if
// it asks to switch to different)
return peerEnc, rxbuf, nil
} }
func _handshakeServer(ctx context.Context, conn net.Conn, version uint32) (err error) { func _handshakeServer(ctx context.Context, conn net.Conn, version uint32) (enc proto.Encoding, rxbuf *xbufReader, err error) {
defer func() { defer func() {
if err != nil { if err != nil {
err = &_HandshakeError{_LinkServer, conn.LocalAddr(), conn.RemoteAddr(), err} err = &_HandshakeError{_LinkServer, conn.LocalAddr(), conn.RemoteAddr(), err}
} }
}() }()
rxbuf = newXBufReader(conn, /*any non-small limit*/1024)
var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error { err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
// rx client hello // rx client hello
var peerVer uint32 var peerVer uint32
var err error var err error
peerVer, err = rxHello("rx hello", conn) peerEnc, peerVer, err = rxHello("rx hello", rxbuf)
if err != nil { if err != nil {
return err return err
} }
...@@ -156,7 +165,7 @@ func _handshakeServer(ctx context.Context, conn net.Conn, version uint32) (err e ...@@ -156,7 +165,7 @@ func _handshakeServer(ctx context.Context, conn net.Conn, version uint32) (err e
// //
// do it before version check so that client can also detect "version // do it before version check so that client can also detect "version
// mismatch" instead of just getting "disconnect". // mismatch" instead of just getting "disconnect".
err = txHello("tx hello reply", conn, version) err = txHello("tx hello reply", conn, version, peerEnc)
if err != nil { if err != nil {
return err return err
} }
...@@ -169,19 +178,43 @@ func _handshakeServer(ctx context.Context, conn net.Conn, version uint32) (err e ...@@ -169,19 +178,43 @@ func _handshakeServer(ctx context.Context, conn net.Conn, version uint32) (err e
return nil return nil
}) })
if err != nil { if err != nil {
return err return 0, nil, err
} }
return nil return peerEnc, rxbuf, nil
} }
func txHello(errctx string, conn net.Conn, version uint32) (err error) { // handshake hello:
//
// - 00 00 00 <ver> for 'N' encoding, and
// - 92 c4 03 NEO ... for 'M' encoding (= msgpack of (b"NEO", <ver>))
//
// the first byte is different from TLS handshake (0x16).
func txHello(errctx string, conn net.Conn, version uint32, enc proto.Encoding) (err error) {
defer xerr.Context(&err, errctx) defer xerr.Context(&err, errctx)
var b [4]byte var b []byte
binary.BigEndian.PutUint32(b[:], version) // XXX -> hton32 ? switch enc {
case 'N':
// 00 00 00 <v>
b = make([]byte, 4)
if version > 0xff {
panic("encoding N supports versions only in range [0, 0xff]")
}
b[3] = uint8(version)
case 'M':
// (b"NEO", <V>) encoded as msgpack (= 92 c4 03 NEO int(<V>))
b = msgp.AppendArrayHeader(b, 2) // 92
b = msgp.AppendBytes(b, []byte("NEO")) // c4 03 NEO
b = msgp.AppendUint32(b, version) // u?intX version
default:
panic("bug")
}
_, err = conn.Write(b[:]) _, err = conn.Write(b)
if err != nil { if err != nil {
return err return err
} }
...@@ -189,18 +222,52 @@ func txHello(errctx string, conn net.Conn, version uint32) (err error) { ...@@ -189,18 +222,52 @@ func txHello(errctx string, conn net.Conn, version uint32) (err error) {
return nil return nil
} }
func rxHello(errctx string, conn net.Conn) (version uint32, err error) { func rxHello(errctx string, rx *xbufReader) (enc proto.Encoding, version uint32, err error) {
defer xerr.Context(&err, errctx) defer xerr.Context(&err, errctx)
var b [4]byte b := make([]byte, 4)
_, err = io.ReadFull(conn, b[:]) _, err = io.ReadFull(rx, b)
err = xio.NoEOF(err) err = xio.NoEOF(err)
if err != nil { if err != nil {
return 0, err return 0, 0, err
}
var peerEnc proto.Encoding
var peerVer uint32
badMagic := false
switch {
case bytes.Equal(b[:3], []byte{0,0,0}):
peerEnc = 'N'
peerVer = uint32(b[3])
case bytes.Equal(b, []byte{0x92, 0xc4, 3, 'N'}): // start of "fixarray<2> bin8 'N | EO' ...
b = append(b, []byte{0,0}...)
_, err = io.ReadFull(rx, b[4:])
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
if !bytes.Equal(b[4:], []byte{'E','O'}) {
badMagic = true
break
}
peerEnc = 'M'
rxM := msgp.Reader{R: rx.Reader}
peerVer, err = rxM.ReadUint32()
if err != nil {
return 0, 0, fmt.Errorf("M: recv peer version: %s", err) // XXX + "read magic" ctx
}
default:
badMagic = true
}
if badMagic {
return 0, 0, fmt.Errorf("invalid magic %x", b)
} }
peerVer := binary.BigEndian.Uint32(b[:]) // XXX -> ntoh32 ? return peerEnc, peerVer, nil
return peerVer, nil
} }
......
...@@ -28,18 +28,25 @@ import ( ...@@ -28,18 +28,25 @@ import (
"lab.nexedi.com/kirr/go123/exc" "lab.nexedi.com/kirr/go123/exc"
"lab.nexedi.com/kirr/go123/xsync" "lab.nexedi.com/kirr/go123/xsync"
"lab.nexedi.com/kirr/neo/go/neo/proto"
) )
// _xhandshakeClient handshakes as client. // _xhandshakeClient handshakes as client with encPrefer encoding and verifies that server accepts it.
func _xhandshakeClient(ctx context.Context, c net.Conn, version uint32) { func _xhandshakeClient(ctx context.Context, c net.Conn, version uint32, encPrefer proto.Encoding) {
err := _handshakeClient(ctx, c, version) enc, _, err := _handshakeClient(ctx, c, version, encPrefer)
exc.Raiseif(err) exc.Raiseif(err)
if enc != encPrefer {
exc.Raisef("enc (%c) != encPrefer (%c)", enc, encPrefer)
}
} }
// _xhandshakeServer handshakes as server. // _xhandshakeServer handshakes as server and verifies negotiated encoding to be encOK.
func _xhandshakeServer(ctx context.Context, c net.Conn, version uint32) { func _xhandshakeServer(ctx context.Context, c net.Conn, version uint32, encOK proto.Encoding) {
err := _handshakeServer(ctx, c, version) enc, _, err := _handshakeServer(ctx, c, version)
exc.Raiseif(err) exc.Raiseif(err)
if enc != encOK {
exc.Raisef("enc (%c) != encOK (%c)", enc, encOK)
}
} }
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
...@@ -51,10 +58,10 @@ func _TestHandshake(t *T) { ...@@ -51,10 +58,10 @@ func _TestHandshake(t *T) {
p1, p2 := net.Pipe() p1, p2 := net.Pipe()
wg := xsync.NewWorkGroup(bg) wg := xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
_xhandshakeClient(ctx, p1, 1) _xhandshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
_xhandshakeServer(ctx, p2, 1) _xhandshakeServer(ctx, p2, 1, t.enc)
}) })
xwait(wg) xwait(wg)
xclose(p1) xclose(p1)
...@@ -65,10 +72,10 @@ func _TestHandshake(t *T) { ...@@ -65,10 +72,10 @@ func _TestHandshake(t *T) {
var err1, err2 error var err1, err2 error
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err1 = _handshakeClient(ctx, p1, 1) _, _, err1 = _handshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err2 = _handshakeServer(ctx, p2, 2) _, _, err2 = _handshakeServer(ctx, p2, 2)
}) })
xwait(wg) xwait(wg)
xclose(p1) xclose(p1)
...@@ -89,7 +96,7 @@ func _TestHandshake(t *T) { ...@@ -89,7 +96,7 @@ func _TestHandshake(t *T) {
var err error var err error
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err = _handshakeClient(ctx, p1, 1) _, _, err = _handshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
xclose(p2) xclose(p2)
...@@ -109,7 +116,7 @@ func _TestHandshake(t *T) { ...@@ -109,7 +116,7 @@ func _TestHandshake(t *T) {
xclose(p1) xclose(p1)
}) })
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err = _handshakeServer(ctx, p2, 1) _, _, err = _handshakeServer(ctx, p2, 1)
}) })
xwait(wg) xwait(wg)
xclose(p2) xclose(p2)
...@@ -124,7 +131,7 @@ func _TestHandshake(t *T) { ...@@ -124,7 +131,7 @@ func _TestHandshake(t *T) {
ctx, cancel := context.WithCancel(bg) ctx, cancel := context.WithCancel(bg)
wg = xsync.NewWorkGroup(ctx) wg = xsync.NewWorkGroup(ctx)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err = _handshakeClient(ctx, p1, 1) _, _, err = _handshakeClient(ctx, p1, 1, t.enc)
}) })
tdelay() tdelay()
cancel() cancel()
...@@ -142,7 +149,7 @@ func _TestHandshake(t *T) { ...@@ -142,7 +149,7 @@ func _TestHandshake(t *T) {
ctx, cancel = context.WithCancel(bg) ctx, cancel = context.WithCancel(bg)
wg = xsync.NewWorkGroup(ctx) wg = xsync.NewWorkGroup(ctx)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err = _handshakeServer(ctx, p2, 1) _, _, err = _handshakeServer(ctx, p2, 1)
}) })
tdelay() tdelay()
cancel() cancel()
......
...@@ -75,8 +75,7 @@ import ( ...@@ -75,8 +75,7 @@ import (
const ( const (
// The protocol version must be increased whenever upgrading a node may require // The protocol version must be increased whenever upgrading a node may require
// to upgrade other nodes. It is encoded as a 4-bytes big-endian integer and // to upgrade other nodes.
// the high order byte 0 is different from TLS Handshake (0x16).
Version = 6 Version = 6
// length of packet header in 'N'-encoding // length of packet header in 'N'-encoding
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment