Commit 1cf2712f authored by Kirill Smelkov's avatar Kirill Smelkov

X on msgpack support

parent d2697535
......@@ -27,6 +27,7 @@ package xcontext
import (
"context"
"errors"
"io"
)
// Cancelled reports whether an error is due to a canceled context.
......@@ -72,3 +73,41 @@ func WhenDone(ctx context.Context, f func()) func() {
close(done)
}
}
// WithCloseOnErrCancel closes c on ctx cancel while f is run, or if f returns with an error.
//
// It is usually handy to propagate cancellation to interrupt IO.
// XXX naming?
// XXX don't close on f return?
func WithCloseOnErrCancel(ctx context.Context, c io.Closer, f func() error) (err error) {
closed := false
fdone := make(chan error)
defer func() {
<-fdone // wait for f to complete
if err != nil {
if !closed {
c.Close()
}
}
}()
go func() (err error) {
defer func() {
fdone <- err
close(fdone)
}()
return f()
}()
select {
case <-ctx.Done():
c.Close() // interrupt IO
closed = true
return ctx.Err()
case err := <-fdone:
return err
}
}
......@@ -489,6 +489,7 @@ func withNEO(t *testing.T, f func(t *testing.T, nsrv NEOSrv, ndrv *Client), optv
withNEOSrv(t, func(t *testing.T, nsrv NEOSrv) {
t.Helper()
X := xtesting.FatalIf(t)
// TODO test for enc=(M|N) (XXX M|N only for NEO/go as NEO/py does not support autodetect)
ndrv, _, err := neoOpen(nsrv.URL(),
&zodb.DriverOptions{ReadOnly: true}); X(err)
defer func() {
......
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
// Package msgpack complements tinylib/msgp in providing runtime support for MessagePack.
//
// https://github.com/msgpack/msgpack/blob/master/spec.md
package msgpack
import (
"encoding/binary"
"math"
)
// Op represents a MessagePack opcode.
type Op byte
const (
FixMap_4 Op = 0b1000_0000 // 1000_XXXX
FixArray_4 Op = 0b1001_0000 // 1001_XXXX
False Op = 0xc2
True Op = 0xc3
Bin8 Op = 0xc4
Bin16 Op = 0xc5
Bin32 Op = 0xc6
Float32 Op = 0xca
Float64 Op = 0xcb
Uint8 Op = 0xcc
Uint16 Op = 0xcd
Uint32 Op = 0xce
Uint64 Op = 0xcf
Int8 Op = 0xd0
Int16 Op = 0xd1
Int32 Op = 0xd2
Int64 Op = 0xd3
FixExt1 Op = 0xd4
FixExt2 Op = 0xd5
FixExt4 Op = 0xd6
Array16 Op = 0xdc
Array32 Op = 0xdd
Map16 Op = 0xde
Map32 Op = 0xdf
)
// op converts Op into byte.
// it is used internally to make sure that only Op is put into encoded data.
func op(x Op) byte {
return byte(x)
}
// Bool returns op corresponding to bool value v.
func Bool(v bool) Op {
if v {
return True
} else {
return False
}
}
// u?intXSize(i) returns size needed to encode i.
func Int8Size (i int8) int { return Int64Size(int64(i)) }
func Int16Size(i int16) int { return Int64Size(int64(i)) }
func Int32Size(i int32) int { return Int64Size(int64(i)) }
func Uint8Size (i uint8) int { return Uint64Size(uint64(i)) }
func Uint16Size(i uint16) int { return Uint64Size(uint64(i)) }
func Uint32Size(i uint32) int { return Uint64Size(uint64(i)) }
// Putu?intX(data, i X) encodes i into data and returns encoded size.
func PutInt8 (data []byte, i int8) int { return PutInt64(data, int64(i)) }
func PutInt16(data []byte, i int16) int { return PutInt64(data, int64(i)) }
func PutInt32(data []byte, i int32) int { return PutInt64(data, int64(i)) }
func PutUint8 (data []byte, i uint8) int { return PutUint64(data, uint64(i)) }
func PutUint16(data []byte, i uint16) int { return PutUint64(data, uint64(i)) }
func PutUint32(data []byte, i uint32) int { return PutUint64(data, uint64(i)) }
func Int64Size(i int64) int {
switch {
case -32 <= i && i <= 0b0_1111111: return 1 // posfixint | negfixint
case int64(int8(i)) == i: return 1+1 // int8 + i8
case int64(int16(i)) == i: return 1+2 // int16 + i16
case int64(int32(i)) == i: return 1+4 // int32 + i32
default: return 1+8 // int64 + u64
}
}
func PutInt64(data []byte, i int64) (n int) {
switch {
// posfixint | negfixint
case -32 <= i && i <= 0b0_1111111:
data[0] = uint8(i)
return 1
// int8 + s8
case int64(int8(i)) == i:
data[0] = op(Int8)
data[1] = uint8(i)
return 1+1
// int16 + s16
case int64(int16(i)) == i:
data[0] = op(Int16)
binary.BigEndian.PutUint16(data[1:], uint16(i))
return 1+2
// int32 + s32
case int64(int32(i)) == i:
data[0] = op(Int32)
binary.BigEndian.PutUint32(data[1:], uint32(i))
return 1+4
// int64 + s64
default:
data[0] = op(Int64)
binary.BigEndian.PutUint64(data[1:], uint64(i))
return 1+8
}
}
func Uint64Size(i uint64) int {
switch {
case i <= 0x7f: return 1 // posfixint
case i <= 0xff: return 1+1 // uint8 + u8
case i <= 0xffff: return 1+2 // uint16 + u16
case i <= 0xffffffff: return 1+4 // uint32 + u32
default: return 1+8 // uint64 + u64
}
}
func PutUint64(data []byte, i uint64) (n int) {
switch {
// posfixint
case i <= 0x7f:
data[0] = uint8(i)
return 1
// uint8 + u8
case i <= math.MaxUint8:
data[0] = op(Uint8)
data[1] = uint8(i)
return 1+1
// uint16 + be16
case i <= math.MaxUint16:
data[0] = op(Uint16)
binary.BigEndian.PutUint16(data[1:], uint16(i))
return 1+2
// uint32 + be32
case i <= math.MaxUint32:
data[0] = op(Uint32)
binary.BigEndian.PutUint32(data[1:], uint32(i))
return 1+4
// uint64 + be64
default:
data[0] = op(Uint64)
binary.BigEndian.PutUint64(data[1:], i)
return 1+8
}
}
// BinHeadSize return number of bytes needed to encode header for [l]bin.
func BinHeadSize(l int) int {
switch {
case l < 0: panic("len < 0")
case l <= math.MaxUint8: return 1+1 // bin8 + len8
case l <= math.MaxUint16: return 1+2 // bin16 + len16
case l <= math.MaxUint32: return 1+4 // bin32 + len32
default: panic("len overflows uint32")
}
}
// PutBinHead puts binary header for [size]bin.
func PutBinHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// bin8 + len8
case l <= 0xff:
data[0] = op(Bin8)
data[1] = uint8(l)
return 1+1
// bin16 + len16
case l <= math.MaxUint16:
data[0] = op(Bin16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// bin32 + len32
case l <= math.MaxUint32:
data[0] = op(Bin32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// ArrayHeadSize returns size for array header for [size]array.
func ArrayHeadSize(l int) int {
switch {
case l < 0: panic("len < 0")
case l <= 0x0f: return 1 // fixarray
case l <= math.MaxUint16: return 1+2 // array16 + len16
case l <= math.MaxUint32: return 1+4 // array32 + len32
default: panic("len overflows uint32")
}
}
// PutArrayHead puts array header for [size]array.
func PutArrayHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// fixarray
case l <= 0x0f:
data[0] = op(FixArray_4 | Op(l))
return 1
// array16 + len16
case l <= math.MaxUint16:
data[0] = op(Array16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// array32 + len32
case l <= math.MaxUint32:
data[0] = op(Array32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// MapHeadSize returns size for map header for [size]map.
func MapHeadSize(l int) int {
return ArrayHeadSize(l) // the same 0x0f/len16/len32 scheme
}
// PutMapHead puts map header for [size]map.
func PutMapHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// fixmap
case l <= 0x0f:
data[0] = op(FixMap_4 | Op(l))
return 1
// map16 + len16
case l <= math.MaxUint16:
data[0] = op(Map16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// map32 + len32
case l <= math.MaxUint32:
data[0] = op(Map32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
package msgpack
import (
hexpkg "encoding/hex"
"testing"
)
// hex decodes string as hex; panics on error.
func hex(s string) string {
b, err := hexpkg.DecodeString(s)
if err != nil {
panic(err)
}
return string(b)
}
// tGetPutSize is interface with Get/Put/Size methods, e.g. with
// GetBinHead/PutBinHead/BinHeadSize.
type tGetPutSize interface {
// XXX Get(data []byte) (n int, ret interface{})
Size(arg interface{}) int
Put(data []byte, arg interface{}) int
}
// test1 verifies enc functions on one argument.
func test1(t *testing.T, enc tGetPutSize, arg interface{}, encoded string) {
t.Helper()
data := make([]byte, 16)
n := enc.Put(data, arg)
got := string(data[:n])
if got != encoded {
t.Errorf("%v -> %x ; want %x", arg, got, encoded)
}
if sz := enc.Size(arg); sz != n {
t.Errorf("size(%v) -> %d ; len(data)=%d", arg, sz, n)
}
// XXX decode == arg, n
// XXX decode([:n-1]) -> overflow
}
type tEncUint64 struct{}
func (_ *tEncUint64) Size(xi interface{}) int { return Uint64Size(xi.(uint64)) }
func (_ *tEncUint64) Put(data []byte, xi interface{}) int { return PutUint64(data, xi.(uint64)) }
func TestUint(t *testing.T) {
h := hex
testv := []struct{i uint64; encoded string}{
{0, h("00")}, // posfixint
{1, h("01")},
{0x7f, h("7f")},
{0x80, h("cc80")}, // uint8
{0xff, h("ccff")},
{0x100, h("cd0100")}, // uint16
{0xffff, h("cdffff")},
{0x10000, h("ce00010000")}, // uint32
{0xffffffff, h("ceffffffff")},
{0x100000000, h("cf0000000100000000")}, // uint64
{0xffffffffffffffff, h("cfffffffffffffffff")},
}
for _, tt := range testv {
test1(t, &tEncUint64{}, tt.i, tt.encoded)
}
}
type tEncInt64 struct{}
func (_ *tEncInt64) Size(xi interface{}) int { return Int64Size(xi.(int64)) }
func (_ *tEncInt64) Put(data []byte, xi interface{}) int { return PutInt64(data, xi.(int64)) }
func TestInt(t *testing.T) {
h := hex
testv := []struct{i int64; encoded string}{
{0, h("00")}, // posfixint
{1, h("01")},
{0x7f, h("7f")},
{-1, h("ff")}, // negfixint
{-2, h("fe")},
{-31, h("e1")},
{-32, h("e0")},
{-33, h("d0df")}, // int8
{-0x7f, h("d081")},
{-0x80, h("d080")},
{0x80, h("d10080")}, // int16
{0x7fff, h("d17fff")},
{-0x7fff, h("d18001")},
{-0x8000, h("d18000")},
{0x8000, h("d200008000")}, // int32
{0x7fffffff, h("d27fffffff")},
{-0x8001, h("d2ffff7fff")},
{-0x7fffffff, h("d280000001")},
{-0x80000000, h("d280000000")},
{0x80000000, h("d30000000080000000")}, // int64
{0x7fffffffffffffff, h("d37fffffffffffffff")},
{-0x80000001, h("d3ffffffff7fffffff")},
{-0x7fffffffffffffff, h("d38000000000000001")},
{-0x8000000000000000, h("d38000000000000000")},
}
for _, tt := range testv {
test1(t, &tEncInt64{}, tt.i, tt.encoded)
}
}
type tEncBinHead struct{}
func (_ *tEncBinHead) Size(xl interface{}) int { return BinHeadSize(xl.(int)) }
func (_ *tEncBinHead) Put(data []byte, xl interface{}) int { return PutBinHead(data, xl.(int)) }
func TestBin(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("c400")}, // bin8
{1, h("c401")},
{0xff, h("c4ff")},
{0x100, h("c50100")}, // bin16
{0xffff, h("c5ffff")},
{0x10000, h("c600010000")}, // bin32
{0xffffffff, h("c6ffffffff")},
}
for _, tt := range testv {
test1(t, &tEncBinHead{}, tt.l, tt.encoded)
}
}
type tEncArrayHead struct{}
func (_ *tEncArrayHead) Size(xl interface{}) int { return ArrayHeadSize(xl.(int)) }
func (_ *tEncArrayHead) Put(data []byte, xl interface{}) int { return PutArrayHead(data, xl.(int)) }
func TestArray(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("90")}, // fixarray
{1, h("91")},
{14, h("9e")},
{15, h("9f")},
{0x10, h("dc0010")}, // array16
{0x11, h("dc0011")},
{0x100, h("dc0100")},
{0xffff, h("dcffff")},
{0x10000, h("dd00010000")}, // array32
{0xffffffff, h("ddffffffff")},
}
for _, tt := range testv {
test1(t, &tEncArrayHead{}, tt.l, tt.encoded)
}
}
type tEncMapHead struct{}
func (_ *tEncMapHead) Size(xl interface{}) int { return MapHeadSize(xl.(int)) }
func (_ *tEncMapHead) Put(data []byte, xl interface{}) int { return PutMapHead(data, xl.(int)) }
func TestMap(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("80")}, // fixmap
{1, h("81")},
{14, h("8e")},
{15, h("8f")},
{0x10, h("de0010")}, // map16
{0x11, h("de0011")},
{0x100, h("de0100")},
{0xffff, h("deffff")},
{0x10000, h("df00010000")}, // map32
{0xffffffff, h("dfffffffff")},
}
for _, tt := range testv {
test1(t, &tEncMapHead{}, tt.l, tt.encoded)
}
}
This diff is collapsed.
This diff is collapsed.
......@@ -21,18 +21,38 @@ package neonet
// link establishment
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"io"
"net"
"sync"
"os"
"github.com/philhofer/fwd"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xerr"
"lab.nexedi.com/kirr/go123/xnet"
"lab.nexedi.com/kirr/neo/go/internal/xcontext"
"lab.nexedi.com/kirr/neo/go/internal/xio"
"lab.nexedi.com/kirr/neo/go/neo/proto"
)
// encDefault is default encoding to use.
// XXX we don't need this? (just set encDefault = 'M')
var encDefault = proto.Encoding('N') // XXX = 'M' instead?
func init() {
e := os.Getenv("NEO_ENCODING")
switch e {
case "": // not set
case "N": fallthrough
case "M": encDefault = proto.Encoding(e[0])
default:
fmt.Fprintf(os.Stderr, "E: $NEO_ENCODING=%q - invalid -> abort", e)
os.Exit(1)
}
}
// ---- Handshake ----
// XXX _Handshake may be needed to become public in case when we have already
......@@ -45,91 +65,215 @@ import (
// On success raw connection is returned wrapped into NodeLink.
// On error raw connection is closed.
func _Handshake(ctx context.Context, conn net.Conn, role _LinkRole) (nl *NodeLink, err error) {
err = handshake(ctx, conn, proto.Version)
enc := encDefault // default encoding
var rxbuf *fwd.Reader
switch role &^ linkFlagsMask {
case _LinkServer:
enc, rxbuf, err = handshakeServer(ctx, conn, proto.Version)
case _LinkClient:
enc, rxbuf, err = handshakeClient(ctx, conn, proto.Version, enc)
default:
panic("bug")
}
if err != nil {
return nil, err
}
// handshake ok -> NodeLink
return newNodeLink(conn, role), nil
return newNodeLink(conn, enc, role, rxbuf), nil
}
// _HandshakeError is returned when there is an error while performing handshake.
type _HandshakeError struct {
LocalRole _LinkRole
LocalAddr net.Addr
RemoteAddr net.Addr
Err error
}
func (e *_HandshakeError) Error() string {
return fmt.Sprintf("%s - %s: handshake: %s", e.LocalAddr, e.RemoteAddr, e.Err.Error())
role := ""
switch e.LocalRole {
case _LinkServer: role = "server"
case _LinkClient: role = "client"
default: panic("bug")
}
return fmt.Sprintf("%s - %s: handshake (%s): %s", e.LocalAddr, e.RemoteAddr, role, e.Err.Error())
}
func handshake(ctx context.Context, conn net.Conn, version uint32) (err error) {
// XXX simplify -> errgroup
errch := make(chan error, 2)
// tx handshake word
txWg := sync.WaitGroup{}
txWg.Add(1)
go func() {
var b [4]byte
binary.BigEndian.PutUint32(b[:], version) // XXX -> hton32 ?
_, err := conn.Write(b[:])
// XXX EOF -> ErrUnexpectedEOF ?
errch <- err
txWg.Done()
}()
func (e *_HandshakeError) Cause() error { return e.Err }
func (e *_HandshakeError) Unwrap() error { return e.Err }
// rx handshake word
go func() {
var b [4]byte
_, err := io.ReadFull(conn, b[:])
err = xio.NoEOF(err) // can be returned with n = 0
if err == nil {
peerVersion := binary.BigEndian.Uint32(b[:]) // XXX -> ntoh32 ?
if peerVersion != version {
err = fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVersion, version)
}
// handshakeClient implements client-side handshake.
//
// Client indicates its version and preferred encoding, but accepts any
// encoding choosen to use by server.
func handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPrefer proto.Encoding) (enc proto.Encoding, rxbuf *fwd.Reader, err error) {
defer func() {
if err != nil {
err = &_HandshakeError{_LinkClient, conn.LocalAddr(), conn.RemoteAddr(), err}
}
errch <- err
}()
connClosed := false
defer func() {
// make sure our version is always sent on the wire, if possible,
// so that peer does not see just closed connection when on rx we see version mismatch.
//
// NOTE if cancelled tx goroutine will wake up without delay.
txWg.Wait()
rxbuf = fwd.NewReader(conn)
// don't forget to close conn if returning with error + add handshake err context
var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
// tx client hello
err := txHello("tx hello", conn, version, encPrefer)
if err != nil {
err = &_HandshakeError{conn.LocalAddr(), conn.RemoteAddr(), err}
if !connClosed {
conn.Close()
}
return err
}
// rx server hello reply
var peerVer uint32
peerEnc, peerVer, err = rxHello("rx hello reply", rxbuf)
if err != nil {
return err
}
// verify version
if peerVer != version {
return fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVer, version)
}
return nil
})
if err != nil {
return 0, nil, err
}
// use peer encoding (server should return the same, but we are ok if
// it asks to switch to different)
return peerEnc, rxbuf, nil
}
// handshakeServer implementss server-side handshake.
//
// Server verifies that its version matches Client and accepts client preferred encoding.
func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (enc proto.Encoding, rxbuf *fwd.Reader, err error) {
defer func() {
if err != nil {
err = &_HandshakeError{_LinkServer, conn.LocalAddr(), conn.RemoteAddr(), err}
}
}()
for i := 0; i < 2; i++ {
select {
case <-ctx.Done():
conn.Close() // interrupt IO
connClosed = true
return ctx.Err()
case err = <-errch:
if err != nil {
return err
}
rxbuf = fwd.NewReader(conn)
var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
// rx client hello
var peerVer uint32
var err error
peerEnc, peerVer, err = rxHello("rx hello", rxbuf)
if err != nil {
return err
}
// tx server reply
//
// do it before version check so that client can also detect "version
// mismatch" instead of just getting "disconnect".
err = txHello("tx hello reply", conn, version, peerEnc)
if err != nil {
return err
}
// verify version
if peerVer != version {
return fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVer, version)
}
return nil
})
if err != nil {
return 0, nil, err
}
return peerEnc, rxbuf, nil
}
func txHello(errctx string, conn net.Conn, version uint32, enc proto.Encoding) (err error) {
defer xerr.Context(&err, errctx)
var b []byte
switch enc {
case 'N':
// 00 00 00 <v>
b = make([]byte, 4)
if version > 0xff {
panic("encoding N supports versions only in range [0, 0xff]")
}
b[3] = uint8(version)
case 'M':
// (b"NEO", <V>) encoded as msgpack (= 92 c4 03 NEO int(<V>))
b = msgp.AppendArrayHeader(b, 2) // 92
b = msgp.AppendBytes(b, []byte("NEO")) // c4 03 NEO
b = msgp.AppendUint32(b, version) // u?intX version
default:
panic("bug")
}
_, err = conn.Write(b)
if err != nil {
return err
}
// handshaked ok
return nil
}
func rxHello(errctx string, rx *fwd.Reader) (enc proto.Encoding, version uint32, err error) {
defer xerr.Context(&err, errctx)
b := make([]byte, 4)
_, err = io.ReadFull(rx, b)
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
var peerEnc proto.Encoding
var peerVer uint32
badMagic := false
switch {
case bytes.Equal(b[:3], []byte{0,0,0}):
peerEnc = encN
peerVer = uint32(b[3])
case bytes.Equal(b, []byte{0x92, 0xc4, 3, 'N'}): // start of "fixarray<2> bin8 'N | EO' ...
b = append(b, []byte{0,0}...)
_, err = io.ReadFull(rx, b[4:])
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
if !bytes.Equal(b[4:], []byte{'E','O'}) {
badMagic = true
break
}
peerEnc = encM
rxM := msgp.Reader{R: rx}
peerVer, err = rxM.ReadUint32()
if err != nil {
return 0, 0, fmt.Errorf("M: recv peer version: %s", err) // XXX + "read magic" ctx
}
default:
badMagic = true
}
if badMagic {
return 0, 0, fmt.Errorf("invalid magic %x", b)
}
return peerEnc, peerVer, nil
}
// ---- Dial & Listen at NodeLink level ----
......@@ -141,6 +285,8 @@ func DialLink(ctx context.Context, net xnet.Networker, addr string) (*NodeLink,
return nil, err
}
// TODO if handshake fails with "closed" (= might be unexpected encoding)
// -> try redial and handshaking with different encoding (= autodetect encoding)
return _Handshake(ctx, peerConn, _LinkClient)
}
......
......@@ -21,29 +21,47 @@ package neonet
import (
"context"
"errors"
"io"
"net"
"testing"
"lab.nexedi.com/kirr/go123/exc"
"lab.nexedi.com/kirr/go123/xsync"
"lab.nexedi.com/kirr/neo/go/neo/proto"
)
func xhandshake(ctx context.Context, c net.Conn, version uint32) {
err := handshake(ctx, c, version)
// xhandshakeClient handshakes as client with encPrefer encoding and verifies that server accepts it.
func xhandshakeClient(ctx context.Context, c net.Conn, version uint32, encPrefer proto.Encoding) {
enc, _, err := handshakeClient(ctx, c, version, encPrefer)
exc.Raiseif(err)
if enc != encPrefer {
exc.Raisef("enc (%c) != encPrefer (%c)", enc, encPrefer)
}
}
// xhandshakeServer handshakes as server and verifies negotiated encoding to be encOK.
func xhandshakeServer(ctx context.Context, c net.Conn, version uint32, encOK proto.Encoding) {
enc, _, err := handshakeServer(ctx, c, version)
exc.Raiseif(err)
if enc != encOK {
exc.Raisef("enc (%c) != encOK (%c)", enc, encOK)
}
}
func TestHandshake(t *testing.T) {
Verify(t, _TestHandshake)
}
func _TestHandshake(t *T) {
bg := context.Background()
// handshake ok
p1, p2 := net.Pipe()
wg := xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) {
xhandshake(ctx, p1, 1)
xhandshakeClient(ctx, p1, 1, t.enc)
})
gox(wg, func(ctx context.Context) {
xhandshake(ctx, p2, 1)
xhandshakeServer(ctx, p2, 1, t.enc)
})
xwait(wg)
xclose(p1)
......@@ -54,17 +72,17 @@ func TestHandshake(t *testing.T) {
var err1, err2 error
wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1)
_, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
})
gox(wg, func(ctx context.Context) {
err2 = handshake(ctx, p2, 2)
_, _, err2 = handshakeServer(ctx, p2, 2)
})
xwait(wg)
xclose(p1)
xclose(p2)
err1Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000002 ; our side = 00000001"
err2Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000001 ; our side = 00000002"
err1Want := "pipe - pipe: handshake (client): protocol version mismatch: peer = 00000002 ; our side = 00000001"
err2Want := "pipe - pipe: handshake (server): protocol version mismatch: peer = 00000001 ; our side = 00000002"
if !(err1 != nil && err1.Error() == err1Want) {
t.Errorf("handshake ver mismatch: p1: unexpected error:\nhave: %v\nwant: %v", err1, err1Want)
......@@ -78,7 +96,7 @@ func TestHandshake(t *testing.T) {
err1, err2 = nil, nil
wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1)
_, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
})
gox(wg, func(_ context.Context) {
xclose(p2)
......@@ -88,16 +106,20 @@ func TestHandshake(t *testing.T) {
err11, ok := err1.(*_HandshakeError)
if !ok || !(err11.Err == io.ErrClosedPipe /* on Write */ || err11.Err == io.ErrUnexpectedEOF /* on Read */) {
if !ok || !(errors.Is(err11.Err, io.ErrClosedPipe /* on Write */) || errors.Is(err11.Err, io.ErrUnexpectedEOF /* on Read */)) {
t.Errorf("handshake peer close: unexpected error: %#v", err1)
}
// XXX same for handshakeServer
// ctx cancel
// XXX same for handshakeServer
p1, p2 = net.Pipe()
ctx, cancel := context.WithCancel(bg)
wg = xsync.NewWorkGroup(ctx)
gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1)
_, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
})
tdelay()
cancel()
......@@ -110,5 +132,4 @@ func TestHandshake(t *testing.T) {
if !ok || !(err11.Err == context.Canceled) {
t.Errorf("handshake cancel: unexpected error: %#v", err1)
}
}
......@@ -39,15 +39,17 @@ type pktBuf struct {
data []byte // whole packet data including all headers
}
// Header returns pointer to packet header.
func (pkt *pktBuf) Header() *proto.PktHeader {
// HeaderN returns pointer to packet header in 'N'-encoding.
func (pkt *pktBuf) Header() *proto.PktHeader { return pkt.HeaderN() } // XXX kill
func (pkt *pktBuf) HeaderN() *proto.PktHeader {
// NOTE no need to check len(.data) < PktHeader:
// .data is always allocated with cap >= PktHeaderLen.
return (*proto.PktHeader)(unsafe.Pointer(&pkt.data[0]))
}
// Payload returns []byte representing packet payload.
func (pkt *pktBuf) Payload() []byte {
// PayloadN returns []byte representing packet payload in 'N'-encoding.
func (pkt *pktBuf) Payload() []byte { return pkt.PayloadN() } // XXX kill
func (pkt *pktBuf) PayloadN() []byte {
return pkt.data[proto.PktHeaderLen:]
}
......@@ -87,6 +89,7 @@ func (pkt *pktBuf) String() string {
h := pkt.Header()
s := fmt.Sprintf(".%d", packed.Ntoh32(h.ConnId))
// XXX encN-specific
msgCode := packed.Ntoh16(h.MsgCode)
msgLen := packed.Ntoh32(h.MsgLen)
data := pkt.Payload()
......@@ -98,7 +101,7 @@ func (pkt *pktBuf) String() string {
// XXX dup wrt Conn.Recv
msg := reflect.New(msgType).Interface().(proto.Msg)
n, err := msg.NEOMsgDecode(data)
n, err := encN.NEOMsgDecode(msg, data) // XXX encN hardcoded
if err != nil {
s += fmt.Sprintf(" (%s) %v; #%d [%d]: % x", msgType.Name(), err, msgLen, len(data), data)
} else {
......
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
package proto
// runtime glue for msgpack support
import (
"fmt"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
)
// mstructDecodeError represents decode error when decoder was expecting
// tuple<nfield> for structure named path.
type mstructDecodeError struct {
path string // "Type.field.field"
op msgpack.Op // op we got
opOk msgpack.Op // op expected
}
func (e *mstructDecodeError) Error() string {
return fmt.Sprintf("decode: M: struct %s: got opcode %02x; expect %02x", e.path, e.op, e.opOk)
}
// mdecodeErr is called to normilize error when msgp.ReadXXX returns err when decoding path.
func mdecodeErr(path string, err error) error {
if err == msgp.ErrShortBytes {
return ErrDecodeOverflow
}
return &mdecodeError{path, err}
}
type mdecodeError struct {
path string // "Type.field.field"
err error
}
func (e *mdecodeError) Error() string {
return fmt.Sprintf("decode: M: %s: %s", e.path, e.err)
}
// mOpError represents decode error when decoder faces unexpected operation.
type mOpError struct {
op, opOk msgpack.Op // op we got and what was expected
}
func (e *mOpError) Error() string {
return fmt.Sprintf("expected opcode %02x; got %02x", e.opOk, e.op)
}
func mdecodeOpErr(path string, op, opOk msgpack.Op) error {
return mdecodeErr(path+"/op", &mOpError{op, opOk})
}
// mLen8Error represents decode error when decoder faces unexpected length in Bin8.
type mLen8Error struct {
l, lOk byte // len we got and expected
}
func (e *mLen8Error) Error() string {
return fmt.Sprintf("expected length %d; got %d", e.lOk, e.l)
}
func mdecodeLen8Err(path string, l, lOk uint8) error {
return mdecodeErr(path+"/len", &mLen8Error{l, lOk})
}
func mdecodeEnumTypeErr(path string, enumType, enumTypeOk byte) error {
return mdecodeErr(path+"/enumType",
fmt.Errorf("expected %d; got %d", enumTypeOk, enumType))
}
func mdecodeEnumValueErr(path string, v byte) error {
return mdecodeErr(path, fmt.Errorf("invalid enum payload %02x", v))
}
......@@ -32,6 +32,12 @@ import (
"time"
)
// MsgCode returns code the corresponds to type of the message.
// XXX place - ok?
func MsgCode(msg Msg) uint16 {
return msg.neoMsgCode()
}
// MsgType looks up message type by message code.
//
// Nil is returned if message code is not valid.
......
......@@ -84,6 +84,7 @@ const (
Version = 6
// length of packet header
// XXX encN-specific ?
PktHeaderLen = 10 // = unsafe.Sizeof(PktHeader{}), but latter gives typed constant (uintptr)
// packets larger than PktMaxSize are not allowed.
......@@ -99,6 +100,7 @@ const (
INVALID_OID zodb.Oid = 1<<64 - 1
)
// XXX encN-specific ?
// PktHeader represents header of a raw packet.
//
// A packet contains connection ID and message.
......@@ -110,31 +112,75 @@ type PktHeader struct {
MsgLen packed.BE32 // payload message length (excluding packet header)
}
// Msg is the interface implemented by all NEO messages.
// Msg is the interface representing a NEO message.
type Msg interface {
// marshal/unmarshal into/from wire format:
// NEOMsgCode returns message code needed to be used for particular message type
// neoMsgCode returns message code needed to be used for particular message type
// on the wire.
NEOMsgCode() uint16
neoMsgCode() uint16
// NEOMsgEncodedLen returns how much space is needed to encode current message payload.
NEOMsgEncodedLen() int
// NEOMsgEncode encodes current message state into buf.
// for encoding E:
//
// - neoMsgEncodedLen<E> returns how much space is needed to encode current message payload via E encoding.
//
// - neoMsgEncode<E> encodes current message state into buf via E encoding.
//
// len(buf) must be >= neoMsgEncodedLen().
NEOMsgEncode(buf []byte)
// len(buf) must be >= neoMsgEncodedLen<E>().
//
// - neoMsgDecode<E> decodes data via E encoding into message in-place.
// N encoding (original struct-based encoding)
neoMsgEncodedLenN() int
neoMsgEncodeN(buf []byte)
neoMsgDecodeN(data []byte) (nread int, err error)
// NEOMsgDecode decodes data into message in-place.
NEOMsgDecode(data []byte) (nread int, err error)
// M encoding (via MessagePack)
neoMsgEncodedLenM() int
neoMsgEncodeM(buf []byte)
neoMsgDecodeM(data []byte) (nread int, err error)
}
// Encoding represents messages encoding.
type Encoding byte
// XXX drop "NEO" prefix?
// NEOMsgEncodedLen returns how much space is needed to encode msg payload via encoding e.
func (e Encoding) NEOMsgEncodedLen(msg Msg) int {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgEncodedLenN()
case 'M': return msg.neoMsgEncodedLenM()
}
}
// NEOMsgEncode encodes msg state into buf via encoding e.
//
// len(buf) must be >= e.NEOMsgEncodedLen(m).
func (e Encoding) NEOMsgEncode(msg Msg, buf []byte) {
switch e {
default: panic("bug")
case 'N': msg.neoMsgEncodeN(buf)
case 'M': msg.neoMsgEncodeM(buf)
}
}
// NEOMsgDecode decodes data via encoding e into msg in-place.
func (e Encoding) NEOMsgDecode(msg Msg, data []byte) (nread int, err error) {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgDecodeN(data)
case 'M': return msg.neoMsgDecodeM(data)
}
}
// ErrDecodeOverflow is the error returned by neoMsgDecode when decoding hits buffer overflow
var ErrDecodeOverflow = errors.New("decode: buffer overflow")
// ---- messages ----
//neo:proto enum
type ErrorCode uint32
const (
ACK ErrorCode = iota
......@@ -155,6 +201,7 @@ const (
// XXX move this to neo.clusterState wrapping proto.ClusterState?
//trace:event traceClusterStateChanged(cs *ClusterState)
//neo:proto enum
type ClusterState int8
const (
// The cluster is initially in the RECOVERING state, and it goes back to
......@@ -188,6 +235,7 @@ const (
STOPPING_BACKUP
)
//neo:proto enum
type NodeType int8
const (
MASTER NodeType = iota
......@@ -196,6 +244,7 @@ const (
ADMIN
)
//neo:proto enum
type NodeState int8
const (
UNKNOWN NodeState = iota //short: U // XXX tag prefix name ?
......@@ -204,6 +253,7 @@ const (
PENDING //short: P
)
//neo:proto enum
type CellState int8
const (
// Write-only cell. Last transactions are missing because storage is/was down
......@@ -255,7 +305,7 @@ type Address struct {
}
// NOTE if Host == "" -> Port not added to wire (see py.PAddress):
func (a *Address) neoEncodedLen() int {
func (a *Address) neoEncodedLenN() int {
l := string_neoEncodedLen(a.Host)
if a.Host != "" {
l += 2
......@@ -263,7 +313,7 @@ func (a *Address) neoEncodedLen() int {
return l
}
func (a *Address) neoEncode(b []byte) int {
func (a *Address) neoEncodeN(b []byte) int {
n := string_neoEncode(a.Host, b[0:])
if a.Host != "" {
binary.BigEndian.PutUint16(b[n:], a.Port)
......@@ -272,7 +322,7 @@ func (a *Address) neoEncode(b []byte) int {
return n
}
func (a *Address) neoDecode(b []byte) (uint64, bool) {
func (a *Address) neoDecodeN(b []byte) (uint64, bool) {
n, ok := string_neoDecode(&a.Host, b)
if !ok {
return 0, false
......@@ -295,17 +345,17 @@ type Checksum [20]byte
// PTid is Partition Table identifier.
//
// Zero value means "invalid id" (<-> None in py.PPTID)
// Zero value means "invalid id" (<-> None in py.PPTID) XXX = nil in msgpack
type PTid uint64
// IdTime represents time of identification.
type IdTime float64
func (t IdTime) neoEncodedLen() int {
func (t IdTime) neoEncodedLenN() int {
return 8
}
func (t IdTime) neoEncode(b []byte) int {
func (t IdTime) neoEncodeN(b []byte) int {
// use -inf as value for no data (NaN != NaN -> hard to use NaN in tests)
// NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer
tt := float64(t)
......@@ -316,7 +366,7 @@ func (t IdTime) neoEncode(b []byte) int {
return 8
}
func (t *IdTime) neoDecode(data []byte) (uint64, bool) {
func (t *IdTime) neoDecodeN(data []byte) (uint64, bool) {
if len(data) < 8 {
return 0, false
}
......@@ -438,8 +488,8 @@ type Recovery struct {
type AnswerRecovery struct {
PTid
BackupTid zodb.Tid
TruncateTid zodb.Tid
BackupTid zodb.Tid // XXX nil <-> 0
TruncateTid zodb.Tid // XXX nil <-> 0
}
// Ask the last OID/TID so that a master can initialize its TransactionManager.
......@@ -1199,13 +1249,13 @@ type FlushLog struct {}
// ---- runtime support for protogen and custom codecs ----
// customCodec is the interface that is implemented by types with custom encodings.
// customCodecN is the interface that is implemented by types with custom N encodings.
//
// its semantic is very similar to Msg.
type customCodec interface {
neoEncodedLen() int
neoEncode(buf []byte) (nwrote int)
neoDecode(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
type customCodecN interface {
neoEncodedLenN() int
neoEncodeN(buf []byte) (nwrote int)
neoDecodeN(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
}
func byte2bool(b byte) bool {
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -387,6 +387,8 @@ func Verify(t *testing.T, f func(*tEnv)) {
// TODO verify M=(go|py) x S=(go|py) x ...
// for now we only verify for all combinations of network
// TODO verify enc=(M|N)
// for all networks
for _, network := range []string{"pipenet", "lonet"} {
opt := tClusterOptions{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment