Commit 1cf2712f authored by Kirill Smelkov's avatar Kirill Smelkov

X on msgpack support

parent d2697535
...@@ -27,6 +27,7 @@ package xcontext ...@@ -27,6 +27,7 @@ package xcontext
import ( import (
"context" "context"
"errors" "errors"
"io"
) )
// Cancelled reports whether an error is due to a canceled context. // Cancelled reports whether an error is due to a canceled context.
...@@ -72,3 +73,41 @@ func WhenDone(ctx context.Context, f func()) func() { ...@@ -72,3 +73,41 @@ func WhenDone(ctx context.Context, f func()) func() {
close(done) close(done)
} }
} }
// WithCloseOnErrCancel closes c on ctx cancel while f is run, or if f returns with an error.
//
// It is usually handy to propagate cancellation to interrupt IO.
// XXX naming?
// XXX don't close on f return?
func WithCloseOnErrCancel(ctx context.Context, c io.Closer, f func() error) (err error) {
closed := false
fdone := make(chan error)
defer func() {
<-fdone // wait for f to complete
if err != nil {
if !closed {
c.Close()
}
}
}()
go func() (err error) {
defer func() {
fdone <- err
close(fdone)
}()
return f()
}()
select {
case <-ctx.Done():
c.Close() // interrupt IO
closed = true
return ctx.Err()
case err := <-fdone:
return err
}
}
...@@ -489,6 +489,7 @@ func withNEO(t *testing.T, f func(t *testing.T, nsrv NEOSrv, ndrv *Client), optv ...@@ -489,6 +489,7 @@ func withNEO(t *testing.T, f func(t *testing.T, nsrv NEOSrv, ndrv *Client), optv
withNEOSrv(t, func(t *testing.T, nsrv NEOSrv) { withNEOSrv(t, func(t *testing.T, nsrv NEOSrv) {
t.Helper() t.Helper()
X := xtesting.FatalIf(t) X := xtesting.FatalIf(t)
// TODO test for enc=(M|N) (XXX M|N only for NEO/go as NEO/py does not support autodetect)
ndrv, _, err := neoOpen(nsrv.URL(), ndrv, _, err := neoOpen(nsrv.URL(),
&zodb.DriverOptions{ReadOnly: true}); X(err) &zodb.DriverOptions{ReadOnly: true}); X(err)
defer func() { defer func() {
......
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
// Package msgpack complements tinylib/msgp in providing runtime support for MessagePack.
//
// https://github.com/msgpack/msgpack/blob/master/spec.md
package msgpack
import (
"encoding/binary"
"math"
)
// Op represents a MessagePack opcode.
type Op byte
const (
FixMap_4 Op = 0b1000_0000 // 1000_XXXX
FixArray_4 Op = 0b1001_0000 // 1001_XXXX
False Op = 0xc2
True Op = 0xc3
Bin8 Op = 0xc4
Bin16 Op = 0xc5
Bin32 Op = 0xc6
Float32 Op = 0xca
Float64 Op = 0xcb
Uint8 Op = 0xcc
Uint16 Op = 0xcd
Uint32 Op = 0xce
Uint64 Op = 0xcf
Int8 Op = 0xd0
Int16 Op = 0xd1
Int32 Op = 0xd2
Int64 Op = 0xd3
FixExt1 Op = 0xd4
FixExt2 Op = 0xd5
FixExt4 Op = 0xd6
Array16 Op = 0xdc
Array32 Op = 0xdd
Map16 Op = 0xde
Map32 Op = 0xdf
)
// op converts Op into byte.
// it is used internally to make sure that only Op is put into encoded data.
func op(x Op) byte {
return byte(x)
}
// Bool returns op corresponding to bool value v.
func Bool(v bool) Op {
if v {
return True
} else {
return False
}
}
// u?intXSize(i) returns size needed to encode i.
func Int8Size (i int8) int { return Int64Size(int64(i)) }
func Int16Size(i int16) int { return Int64Size(int64(i)) }
func Int32Size(i int32) int { return Int64Size(int64(i)) }
func Uint8Size (i uint8) int { return Uint64Size(uint64(i)) }
func Uint16Size(i uint16) int { return Uint64Size(uint64(i)) }
func Uint32Size(i uint32) int { return Uint64Size(uint64(i)) }
// Putu?intX(data, i X) encodes i into data and returns encoded size.
func PutInt8 (data []byte, i int8) int { return PutInt64(data, int64(i)) }
func PutInt16(data []byte, i int16) int { return PutInt64(data, int64(i)) }
func PutInt32(data []byte, i int32) int { return PutInt64(data, int64(i)) }
func PutUint8 (data []byte, i uint8) int { return PutUint64(data, uint64(i)) }
func PutUint16(data []byte, i uint16) int { return PutUint64(data, uint64(i)) }
func PutUint32(data []byte, i uint32) int { return PutUint64(data, uint64(i)) }
func Int64Size(i int64) int {
switch {
case -32 <= i && i <= 0b0_1111111: return 1 // posfixint | negfixint
case int64(int8(i)) == i: return 1+1 // int8 + i8
case int64(int16(i)) == i: return 1+2 // int16 + i16
case int64(int32(i)) == i: return 1+4 // int32 + i32
default: return 1+8 // int64 + u64
}
}
func PutInt64(data []byte, i int64) (n int) {
switch {
// posfixint | negfixint
case -32 <= i && i <= 0b0_1111111:
data[0] = uint8(i)
return 1
// int8 + s8
case int64(int8(i)) == i:
data[0] = op(Int8)
data[1] = uint8(i)
return 1+1
// int16 + s16
case int64(int16(i)) == i:
data[0] = op(Int16)
binary.BigEndian.PutUint16(data[1:], uint16(i))
return 1+2
// int32 + s32
case int64(int32(i)) == i:
data[0] = op(Int32)
binary.BigEndian.PutUint32(data[1:], uint32(i))
return 1+4
// int64 + s64
default:
data[0] = op(Int64)
binary.BigEndian.PutUint64(data[1:], uint64(i))
return 1+8
}
}
func Uint64Size(i uint64) int {
switch {
case i <= 0x7f: return 1 // posfixint
case i <= 0xff: return 1+1 // uint8 + u8
case i <= 0xffff: return 1+2 // uint16 + u16
case i <= 0xffffffff: return 1+4 // uint32 + u32
default: return 1+8 // uint64 + u64
}
}
func PutUint64(data []byte, i uint64) (n int) {
switch {
// posfixint
case i <= 0x7f:
data[0] = uint8(i)
return 1
// uint8 + u8
case i <= math.MaxUint8:
data[0] = op(Uint8)
data[1] = uint8(i)
return 1+1
// uint16 + be16
case i <= math.MaxUint16:
data[0] = op(Uint16)
binary.BigEndian.PutUint16(data[1:], uint16(i))
return 1+2
// uint32 + be32
case i <= math.MaxUint32:
data[0] = op(Uint32)
binary.BigEndian.PutUint32(data[1:], uint32(i))
return 1+4
// uint64 + be64
default:
data[0] = op(Uint64)
binary.BigEndian.PutUint64(data[1:], i)
return 1+8
}
}
// BinHeadSize return number of bytes needed to encode header for [l]bin.
func BinHeadSize(l int) int {
switch {
case l < 0: panic("len < 0")
case l <= math.MaxUint8: return 1+1 // bin8 + len8
case l <= math.MaxUint16: return 1+2 // bin16 + len16
case l <= math.MaxUint32: return 1+4 // bin32 + len32
default: panic("len overflows uint32")
}
}
// PutBinHead puts binary header for [size]bin.
func PutBinHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// bin8 + len8
case l <= 0xff:
data[0] = op(Bin8)
data[1] = uint8(l)
return 1+1
// bin16 + len16
case l <= math.MaxUint16:
data[0] = op(Bin16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// bin32 + len32
case l <= math.MaxUint32:
data[0] = op(Bin32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// ArrayHeadSize returns size for array header for [size]array.
func ArrayHeadSize(l int) int {
switch {
case l < 0: panic("len < 0")
case l <= 0x0f: return 1 // fixarray
case l <= math.MaxUint16: return 1+2 // array16 + len16
case l <= math.MaxUint32: return 1+4 // array32 + len32
default: panic("len overflows uint32")
}
}
// PutArrayHead puts array header for [size]array.
func PutArrayHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// fixarray
case l <= 0x0f:
data[0] = op(FixArray_4 | Op(l))
return 1
// array16 + len16
case l <= math.MaxUint16:
data[0] = op(Array16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// array32 + len32
case l <= math.MaxUint32:
data[0] = op(Array32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// MapHeadSize returns size for map header for [size]map.
func MapHeadSize(l int) int {
return ArrayHeadSize(l) // the same 0x0f/len16/len32 scheme
}
// PutMapHead puts map header for [size]map.
func PutMapHead(data []byte, l int) (n int) {
switch {
case l < 0: panic("len < 0")
// fixmap
case l <= 0x0f:
data[0] = op(FixMap_4 | Op(l))
return 1
// map16 + len16
case l <= math.MaxUint16:
data[0] = op(Map16)
binary.BigEndian.PutUint16(data[1:], uint16(l))
return 1+2
// map32 + len32
case l <= math.MaxUint32:
data[0] = op(Map32)
binary.BigEndian.PutUint32(data[1:], uint32(l))
return 1+4
default: panic("len overflows uint32")
}
}
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
package msgpack
import (
hexpkg "encoding/hex"
"testing"
)
// hex decodes string as hex; panics on error.
func hex(s string) string {
b, err := hexpkg.DecodeString(s)
if err != nil {
panic(err)
}
return string(b)
}
// tGetPutSize is interface with Get/Put/Size methods, e.g. with
// GetBinHead/PutBinHead/BinHeadSize.
type tGetPutSize interface {
// XXX Get(data []byte) (n int, ret interface{})
Size(arg interface{}) int
Put(data []byte, arg interface{}) int
}
// test1 verifies enc functions on one argument.
func test1(t *testing.T, enc tGetPutSize, arg interface{}, encoded string) {
t.Helper()
data := make([]byte, 16)
n := enc.Put(data, arg)
got := string(data[:n])
if got != encoded {
t.Errorf("%v -> %x ; want %x", arg, got, encoded)
}
if sz := enc.Size(arg); sz != n {
t.Errorf("size(%v) -> %d ; len(data)=%d", arg, sz, n)
}
// XXX decode == arg, n
// XXX decode([:n-1]) -> overflow
}
type tEncUint64 struct{}
func (_ *tEncUint64) Size(xi interface{}) int { return Uint64Size(xi.(uint64)) }
func (_ *tEncUint64) Put(data []byte, xi interface{}) int { return PutUint64(data, xi.(uint64)) }
func TestUint(t *testing.T) {
h := hex
testv := []struct{i uint64; encoded string}{
{0, h("00")}, // posfixint
{1, h("01")},
{0x7f, h("7f")},
{0x80, h("cc80")}, // uint8
{0xff, h("ccff")},
{0x100, h("cd0100")}, // uint16
{0xffff, h("cdffff")},
{0x10000, h("ce00010000")}, // uint32
{0xffffffff, h("ceffffffff")},
{0x100000000, h("cf0000000100000000")}, // uint64
{0xffffffffffffffff, h("cfffffffffffffffff")},
}
for _, tt := range testv {
test1(t, &tEncUint64{}, tt.i, tt.encoded)
}
}
type tEncInt64 struct{}
func (_ *tEncInt64) Size(xi interface{}) int { return Int64Size(xi.(int64)) }
func (_ *tEncInt64) Put(data []byte, xi interface{}) int { return PutInt64(data, xi.(int64)) }
func TestInt(t *testing.T) {
h := hex
testv := []struct{i int64; encoded string}{
{0, h("00")}, // posfixint
{1, h("01")},
{0x7f, h("7f")},
{-1, h("ff")}, // negfixint
{-2, h("fe")},
{-31, h("e1")},
{-32, h("e0")},
{-33, h("d0df")}, // int8
{-0x7f, h("d081")},
{-0x80, h("d080")},
{0x80, h("d10080")}, // int16
{0x7fff, h("d17fff")},
{-0x7fff, h("d18001")},
{-0x8000, h("d18000")},
{0x8000, h("d200008000")}, // int32
{0x7fffffff, h("d27fffffff")},
{-0x8001, h("d2ffff7fff")},
{-0x7fffffff, h("d280000001")},
{-0x80000000, h("d280000000")},
{0x80000000, h("d30000000080000000")}, // int64
{0x7fffffffffffffff, h("d37fffffffffffffff")},
{-0x80000001, h("d3ffffffff7fffffff")},
{-0x7fffffffffffffff, h("d38000000000000001")},
{-0x8000000000000000, h("d38000000000000000")},
}
for _, tt := range testv {
test1(t, &tEncInt64{}, tt.i, tt.encoded)
}
}
type tEncBinHead struct{}
func (_ *tEncBinHead) Size(xl interface{}) int { return BinHeadSize(xl.(int)) }
func (_ *tEncBinHead) Put(data []byte, xl interface{}) int { return PutBinHead(data, xl.(int)) }
func TestBin(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("c400")}, // bin8
{1, h("c401")},
{0xff, h("c4ff")},
{0x100, h("c50100")}, // bin16
{0xffff, h("c5ffff")},
{0x10000, h("c600010000")}, // bin32
{0xffffffff, h("c6ffffffff")},
}
for _, tt := range testv {
test1(t, &tEncBinHead{}, tt.l, tt.encoded)
}
}
type tEncArrayHead struct{}
func (_ *tEncArrayHead) Size(xl interface{}) int { return ArrayHeadSize(xl.(int)) }
func (_ *tEncArrayHead) Put(data []byte, xl interface{}) int { return PutArrayHead(data, xl.(int)) }
func TestArray(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("90")}, // fixarray
{1, h("91")},
{14, h("9e")},
{15, h("9f")},
{0x10, h("dc0010")}, // array16
{0x11, h("dc0011")},
{0x100, h("dc0100")},
{0xffff, h("dcffff")},
{0x10000, h("dd00010000")}, // array32
{0xffffffff, h("ddffffffff")},
}
for _, tt := range testv {
test1(t, &tEncArrayHead{}, tt.l, tt.encoded)
}
}
type tEncMapHead struct{}
func (_ *tEncMapHead) Size(xl interface{}) int { return MapHeadSize(xl.(int)) }
func (_ *tEncMapHead) Put(data []byte, xl interface{}) int { return PutMapHead(data, xl.(int)) }
func TestMap(t *testing.T) {
h := hex
testv := []struct{l int; encoded string}{
{0, h("80")}, // fixmap
{1, h("81")},
{14, h("8e")},
{15, h("8f")},
{0x10, h("de0010")}, // map16
{0x11, h("de0011")},
{0x100, h("de0100")},
{0xffff, h("deffff")},
{0x10000, h("df00010000")}, // map32
{0xffffffff, h("dfffffffff")},
}
for _, tt := range testv {
test1(t, &tEncMapHead{}, tt.l, tt.encoded)
}
}
This diff is collapsed.
This diff is collapsed.
...@@ -21,18 +21,38 @@ package neonet ...@@ -21,18 +21,38 @@ package neonet
// link establishment // link establishment
import ( import (
"bytes"
"context" "context"
"encoding/binary"
"fmt" "fmt"
"io" "io"
"net" "net"
"sync" "os"
"github.com/philhofer/fwd"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xerr"
"lab.nexedi.com/kirr/go123/xnet" "lab.nexedi.com/kirr/go123/xnet"
"lab.nexedi.com/kirr/neo/go/internal/xcontext"
"lab.nexedi.com/kirr/neo/go/internal/xio" "lab.nexedi.com/kirr/neo/go/internal/xio"
"lab.nexedi.com/kirr/neo/go/neo/proto" "lab.nexedi.com/kirr/neo/go/neo/proto"
) )
// encDefault is default encoding to use.
// XXX we don't need this? (just set encDefault = 'M')
var encDefault = proto.Encoding('N') // XXX = 'M' instead?
func init() {
e := os.Getenv("NEO_ENCODING")
switch e {
case "": // not set
case "N": fallthrough
case "M": encDefault = proto.Encoding(e[0])
default:
fmt.Fprintf(os.Stderr, "E: $NEO_ENCODING=%q - invalid -> abort", e)
os.Exit(1)
}
}
// ---- Handshake ---- // ---- Handshake ----
// XXX _Handshake may be needed to become public in case when we have already // XXX _Handshake may be needed to become public in case when we have already
...@@ -45,89 +65,213 @@ import ( ...@@ -45,89 +65,213 @@ import (
// On success raw connection is returned wrapped into NodeLink. // On success raw connection is returned wrapped into NodeLink.
// On error raw connection is closed. // On error raw connection is closed.
func _Handshake(ctx context.Context, conn net.Conn, role _LinkRole) (nl *NodeLink, err error) { func _Handshake(ctx context.Context, conn net.Conn, role _LinkRole) (nl *NodeLink, err error) {
err = handshake(ctx, conn, proto.Version) enc := encDefault // default encoding
var rxbuf *fwd.Reader
switch role &^ linkFlagsMask {
case _LinkServer:
enc, rxbuf, err = handshakeServer(ctx, conn, proto.Version)
case _LinkClient:
enc, rxbuf, err = handshakeClient(ctx, conn, proto.Version, enc)
default:
panic("bug")
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
// handshake ok -> NodeLink // handshake ok -> NodeLink
return newNodeLink(conn, role), nil return newNodeLink(conn, enc, role, rxbuf), nil
} }
// _HandshakeError is returned when there is an error while performing handshake. // _HandshakeError is returned when there is an error while performing handshake.
type _HandshakeError struct { type _HandshakeError struct {
LocalRole _LinkRole
LocalAddr net.Addr LocalAddr net.Addr
RemoteAddr net.Addr RemoteAddr net.Addr
Err error Err error
} }
func (e *_HandshakeError) Error() string { func (e *_HandshakeError) Error() string {
return fmt.Sprintf("%s - %s: handshake: %s", e.LocalAddr, e.RemoteAddr, e.Err.Error()) role := ""
switch e.LocalRole {
case _LinkServer: role = "server"
case _LinkClient: role = "client"
default: panic("bug")
}
return fmt.Sprintf("%s - %s: handshake (%s): %s", e.LocalAddr, e.RemoteAddr, role, e.Err.Error())
} }
func handshake(ctx context.Context, conn net.Conn, version uint32) (err error) { func (e *_HandshakeError) Cause() error { return e.Err }
// XXX simplify -> errgroup func (e *_HandshakeError) Unwrap() error { return e.Err }
errch := make(chan error, 2)
// handshakeClient implements client-side handshake.
// tx handshake word //
txWg := sync.WaitGroup{} // Client indicates its version and preferred encoding, but accepts any
txWg.Add(1) // encoding choosen to use by server.
go func() { func handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPrefer proto.Encoding) (enc proto.Encoding, rxbuf *fwd.Reader, err error) {
var b [4]byte defer func() {
binary.BigEndian.PutUint32(b[:], version) // XXX -> hton32 ? if err != nil {
_, err := conn.Write(b[:]) err = &_HandshakeError{_LinkClient, conn.LocalAddr(), conn.RemoteAddr(), err}
// XXX EOF -> ErrUnexpectedEOF ? }
errch <- err
txWg.Done()
}() }()
// rx handshake word rxbuf = fwd.NewReader(conn)
go func() {
var b [4]byte var peerEnc proto.Encoding
_, err := io.ReadFull(conn, b[:]) err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
err = xio.NoEOF(err) // can be returned with n = 0 // tx client hello
if err == nil { err := txHello("tx hello", conn, version, encPrefer)
peerVersion := binary.BigEndian.Uint32(b[:]) // XXX -> ntoh32 ? if err != nil {
if peerVersion != version { return err
err = fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVersion, version)
} }
// rx server hello reply
var peerVer uint32
peerEnc, peerVer, err = rxHello("rx hello reply", rxbuf)
if err != nil {
return err
} }
errch <- err
}()
connClosed := false // verify version
defer func() { if peerVer != version {
// make sure our version is always sent on the wire, if possible, return fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVer, version)
// so that peer does not see just closed connection when on rx we see version mismatch. }
//
// NOTE if cancelled tx goroutine will wake up without delay.
txWg.Wait()
// don't forget to close conn if returning with error + add handshake err context return nil
})
if err != nil { if err != nil {
err = &_HandshakeError{conn.LocalAddr(), conn.RemoteAddr(), err} return 0, nil, err
if !connClosed {
conn.Close()
} }
// use peer encoding (server should return the same, but we are ok if
// it asks to switch to different)
return peerEnc, rxbuf, nil
}
// handshakeServer implementss server-side handshake.
//
// Server verifies that its version matches Client and accepts client preferred encoding.
func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (enc proto.Encoding, rxbuf *fwd.Reader, err error) {
defer func() {
if err != nil {
err = &_HandshakeError{_LinkServer, conn.LocalAddr(), conn.RemoteAddr(), err}
} }
}() }()
for i := 0; i < 2; i++ { rxbuf = fwd.NewReader(conn)
select {
case <-ctx.Done():
conn.Close() // interrupt IO
connClosed = true
return ctx.Err()
case err = <-errch: var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
// rx client hello
var peerVer uint32
var err error
peerEnc, peerVer, err = rxHello("rx hello", rxbuf)
if err != nil { if err != nil {
return err return err
} }
// tx server reply
//
// do it before version check so that client can also detect "version
// mismatch" instead of just getting "disconnect".
err = txHello("tx hello reply", conn, version, peerEnc)
if err != nil {
return err
} }
// verify version
if peerVer != version {
return fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVer, version)
} }
// handshaked ok
return nil return nil
})
if err != nil {
return 0, nil, err
}
return peerEnc, rxbuf, nil
}
func txHello(errctx string, conn net.Conn, version uint32, enc proto.Encoding) (err error) {
defer xerr.Context(&err, errctx)
var b []byte
switch enc {
case 'N':
// 00 00 00 <v>
b = make([]byte, 4)
if version > 0xff {
panic("encoding N supports versions only in range [0, 0xff]")
}
b[3] = uint8(version)
case 'M':
// (b"NEO", <V>) encoded as msgpack (= 92 c4 03 NEO int(<V>))
b = msgp.AppendArrayHeader(b, 2) // 92
b = msgp.AppendBytes(b, []byte("NEO")) // c4 03 NEO
b = msgp.AppendUint32(b, version) // u?intX version
default:
panic("bug")
}
_, err = conn.Write(b)
if err != nil {
return err
}
return nil
}
func rxHello(errctx string, rx *fwd.Reader) (enc proto.Encoding, version uint32, err error) {
defer xerr.Context(&err, errctx)
b := make([]byte, 4)
_, err = io.ReadFull(rx, b)
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
var peerEnc proto.Encoding
var peerVer uint32
badMagic := false
switch {
case bytes.Equal(b[:3], []byte{0,0,0}):
peerEnc = encN
peerVer = uint32(b[3])
case bytes.Equal(b, []byte{0x92, 0xc4, 3, 'N'}): // start of "fixarray<2> bin8 'N | EO' ...
b = append(b, []byte{0,0}...)
_, err = io.ReadFull(rx, b[4:])
err = xio.NoEOF(err)
if err != nil {
return 0, 0, err
}
if !bytes.Equal(b[4:], []byte{'E','O'}) {
badMagic = true
break
}
peerEnc = encM
rxM := msgp.Reader{R: rx}
peerVer, err = rxM.ReadUint32()
if err != nil {
return 0, 0, fmt.Errorf("M: recv peer version: %s", err) // XXX + "read magic" ctx
}
default:
badMagic = true
}
if badMagic {
return 0, 0, fmt.Errorf("invalid magic %x", b)
}
return peerEnc, peerVer, nil
} }
...@@ -141,6 +285,8 @@ func DialLink(ctx context.Context, net xnet.Networker, addr string) (*NodeLink, ...@@ -141,6 +285,8 @@ func DialLink(ctx context.Context, net xnet.Networker, addr string) (*NodeLink,
return nil, err return nil, err
} }
// TODO if handshake fails with "closed" (= might be unexpected encoding)
// -> try redial and handshaking with different encoding (= autodetect encoding)
return _Handshake(ctx, peerConn, _LinkClient) return _Handshake(ctx, peerConn, _LinkClient)
} }
......
...@@ -21,29 +21,47 @@ package neonet ...@@ -21,29 +21,47 @@ package neonet
import ( import (
"context" "context"
"errors"
"io" "io"
"net" "net"
"testing" "testing"
"lab.nexedi.com/kirr/go123/exc" "lab.nexedi.com/kirr/go123/exc"
"lab.nexedi.com/kirr/go123/xsync" "lab.nexedi.com/kirr/go123/xsync"
"lab.nexedi.com/kirr/neo/go/neo/proto"
) )
func xhandshake(ctx context.Context, c net.Conn, version uint32) { // xhandshakeClient handshakes as client with encPrefer encoding and verifies that server accepts it.
err := handshake(ctx, c, version) func xhandshakeClient(ctx context.Context, c net.Conn, version uint32, encPrefer proto.Encoding) {
enc, _, err := handshakeClient(ctx, c, version, encPrefer)
exc.Raiseif(err) exc.Raiseif(err)
if enc != encPrefer {
exc.Raisef("enc (%c) != encPrefer (%c)", enc, encPrefer)
}
}
// xhandshakeServer handshakes as server and verifies negotiated encoding to be encOK.
func xhandshakeServer(ctx context.Context, c net.Conn, version uint32, encOK proto.Encoding) {
enc, _, err := handshakeServer(ctx, c, version)
exc.Raiseif(err)
if enc != encOK {
exc.Raisef("enc (%c) != encOK (%c)", enc, encOK)
}
} }
func TestHandshake(t *testing.T) { func TestHandshake(t *testing.T) {
Verify(t, _TestHandshake)
}
func _TestHandshake(t *T) {
bg := context.Background() bg := context.Background()
// handshake ok // handshake ok
p1, p2 := net.Pipe() p1, p2 := net.Pipe()
wg := xsync.NewWorkGroup(bg) wg := xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
xhandshake(ctx, p1, 1) xhandshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
xhandshake(ctx, p2, 1) xhandshakeServer(ctx, p2, 1, t.enc)
}) })
xwait(wg) xwait(wg)
xclose(p1) xclose(p1)
...@@ -54,17 +72,17 @@ func TestHandshake(t *testing.T) { ...@@ -54,17 +72,17 @@ func TestHandshake(t *testing.T) {
var err1, err2 error var err1, err2 error
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1) _, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err2 = handshake(ctx, p2, 2) _, _, err2 = handshakeServer(ctx, p2, 2)
}) })
xwait(wg) xwait(wg)
xclose(p1) xclose(p1)
xclose(p2) xclose(p2)
err1Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000002 ; our side = 00000001" err1Want := "pipe - pipe: handshake (client): protocol version mismatch: peer = 00000002 ; our side = 00000001"
err2Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000001 ; our side = 00000002" err2Want := "pipe - pipe: handshake (server): protocol version mismatch: peer = 00000001 ; our side = 00000002"
if !(err1 != nil && err1.Error() == err1Want) { if !(err1 != nil && err1.Error() == err1Want) {
t.Errorf("handshake ver mismatch: p1: unexpected error:\nhave: %v\nwant: %v", err1, err1Want) t.Errorf("handshake ver mismatch: p1: unexpected error:\nhave: %v\nwant: %v", err1, err1Want)
...@@ -78,7 +96,7 @@ func TestHandshake(t *testing.T) { ...@@ -78,7 +96,7 @@ func TestHandshake(t *testing.T) {
err1, err2 = nil, nil err1, err2 = nil, nil
wg = xsync.NewWorkGroup(bg) wg = xsync.NewWorkGroup(bg)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1) _, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
}) })
gox(wg, func(_ context.Context) { gox(wg, func(_ context.Context) {
xclose(p2) xclose(p2)
...@@ -88,16 +106,20 @@ func TestHandshake(t *testing.T) { ...@@ -88,16 +106,20 @@ func TestHandshake(t *testing.T) {
err11, ok := err1.(*_HandshakeError) err11, ok := err1.(*_HandshakeError)
if !ok || !(err11.Err == io.ErrClosedPipe /* on Write */ || err11.Err == io.ErrUnexpectedEOF /* on Read */) { if !ok || !(errors.Is(err11.Err, io.ErrClosedPipe /* on Write */) || errors.Is(err11.Err, io.ErrUnexpectedEOF /* on Read */)) {
t.Errorf("handshake peer close: unexpected error: %#v", err1) t.Errorf("handshake peer close: unexpected error: %#v", err1)
} }
// XXX same for handshakeServer
// ctx cancel // ctx cancel
// XXX same for handshakeServer
p1, p2 = net.Pipe() p1, p2 = net.Pipe()
ctx, cancel := context.WithCancel(bg) ctx, cancel := context.WithCancel(bg)
wg = xsync.NewWorkGroup(ctx) wg = xsync.NewWorkGroup(ctx)
gox(wg, func(ctx context.Context) { gox(wg, func(ctx context.Context) {
err1 = handshake(ctx, p1, 1) _, _, err1 = handshakeClient(ctx, p1, 1, t.enc)
}) })
tdelay() tdelay()
cancel() cancel()
...@@ -110,5 +132,4 @@ func TestHandshake(t *testing.T) { ...@@ -110,5 +132,4 @@ func TestHandshake(t *testing.T) {
if !ok || !(err11.Err == context.Canceled) { if !ok || !(err11.Err == context.Canceled) {
t.Errorf("handshake cancel: unexpected error: %#v", err1) t.Errorf("handshake cancel: unexpected error: %#v", err1)
} }
} }
...@@ -39,15 +39,17 @@ type pktBuf struct { ...@@ -39,15 +39,17 @@ type pktBuf struct {
data []byte // whole packet data including all headers data []byte // whole packet data including all headers
} }
// Header returns pointer to packet header. // HeaderN returns pointer to packet header in 'N'-encoding.
func (pkt *pktBuf) Header() *proto.PktHeader { func (pkt *pktBuf) Header() *proto.PktHeader { return pkt.HeaderN() } // XXX kill
func (pkt *pktBuf) HeaderN() *proto.PktHeader {
// NOTE no need to check len(.data) < PktHeader: // NOTE no need to check len(.data) < PktHeader:
// .data is always allocated with cap >= PktHeaderLen. // .data is always allocated with cap >= PktHeaderLen.
return (*proto.PktHeader)(unsafe.Pointer(&pkt.data[0])) return (*proto.PktHeader)(unsafe.Pointer(&pkt.data[0]))
} }
// Payload returns []byte representing packet payload. // PayloadN returns []byte representing packet payload in 'N'-encoding.
func (pkt *pktBuf) Payload() []byte { func (pkt *pktBuf) Payload() []byte { return pkt.PayloadN() } // XXX kill
func (pkt *pktBuf) PayloadN() []byte {
return pkt.data[proto.PktHeaderLen:] return pkt.data[proto.PktHeaderLen:]
} }
...@@ -87,6 +89,7 @@ func (pkt *pktBuf) String() string { ...@@ -87,6 +89,7 @@ func (pkt *pktBuf) String() string {
h := pkt.Header() h := pkt.Header()
s := fmt.Sprintf(".%d", packed.Ntoh32(h.ConnId)) s := fmt.Sprintf(".%d", packed.Ntoh32(h.ConnId))
// XXX encN-specific
msgCode := packed.Ntoh16(h.MsgCode) msgCode := packed.Ntoh16(h.MsgCode)
msgLen := packed.Ntoh32(h.MsgLen) msgLen := packed.Ntoh32(h.MsgLen)
data := pkt.Payload() data := pkt.Payload()
...@@ -98,7 +101,7 @@ func (pkt *pktBuf) String() string { ...@@ -98,7 +101,7 @@ func (pkt *pktBuf) String() string {
// XXX dup wrt Conn.Recv // XXX dup wrt Conn.Recv
msg := reflect.New(msgType).Interface().(proto.Msg) msg := reflect.New(msgType).Interface().(proto.Msg)
n, err := msg.NEOMsgDecode(data) n, err := encN.NEOMsgDecode(msg, data) // XXX encN hardcoded
if err != nil { if err != nil {
s += fmt.Sprintf(" (%s) %v; #%d [%d]: % x", msgType.Name(), err, msgLen, len(data), data) s += fmt.Sprintf(" (%s) %v; #%d [%d]: % x", msgType.Name(), err, msgLen, len(data), data)
} else { } else {
......
// Copyright (C) 2020-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.
package proto
// runtime glue for msgpack support
import (
"fmt"
"github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
)
// mstructDecodeError represents decode error when decoder was expecting
// tuple<nfield> for structure named path.
type mstructDecodeError struct {
path string // "Type.field.field"
op msgpack.Op // op we got
opOk msgpack.Op // op expected
}
func (e *mstructDecodeError) Error() string {
return fmt.Sprintf("decode: M: struct %s: got opcode %02x; expect %02x", e.path, e.op, e.opOk)
}
// mdecodeErr is called to normilize error when msgp.ReadXXX returns err when decoding path.
func mdecodeErr(path string, err error) error {
if err == msgp.ErrShortBytes {
return ErrDecodeOverflow
}
return &mdecodeError{path, err}
}
type mdecodeError struct {
path string // "Type.field.field"
err error
}
func (e *mdecodeError) Error() string {
return fmt.Sprintf("decode: M: %s: %s", e.path, e.err)
}
// mOpError represents decode error when decoder faces unexpected operation.
type mOpError struct {
op, opOk msgpack.Op // op we got and what was expected
}
func (e *mOpError) Error() string {
return fmt.Sprintf("expected opcode %02x; got %02x", e.opOk, e.op)
}
func mdecodeOpErr(path string, op, opOk msgpack.Op) error {
return mdecodeErr(path+"/op", &mOpError{op, opOk})
}
// mLen8Error represents decode error when decoder faces unexpected length in Bin8.
type mLen8Error struct {
l, lOk byte // len we got and expected
}
func (e *mLen8Error) Error() string {
return fmt.Sprintf("expected length %d; got %d", e.lOk, e.l)
}
func mdecodeLen8Err(path string, l, lOk uint8) error {
return mdecodeErr(path+"/len", &mLen8Error{l, lOk})
}
func mdecodeEnumTypeErr(path string, enumType, enumTypeOk byte) error {
return mdecodeErr(path+"/enumType",
fmt.Errorf("expected %d; got %d", enumTypeOk, enumType))
}
func mdecodeEnumValueErr(path string, v byte) error {
return mdecodeErr(path, fmt.Errorf("invalid enum payload %02x", v))
}
...@@ -32,6 +32,12 @@ import ( ...@@ -32,6 +32,12 @@ import (
"time" "time"
) )
// MsgCode returns code the corresponds to type of the message.
// XXX place - ok?
func MsgCode(msg Msg) uint16 {
return msg.neoMsgCode()
}
// MsgType looks up message type by message code. // MsgType looks up message type by message code.
// //
// Nil is returned if message code is not valid. // Nil is returned if message code is not valid.
......
...@@ -84,6 +84,7 @@ const ( ...@@ -84,6 +84,7 @@ const (
Version = 6 Version = 6
// length of packet header // length of packet header
// XXX encN-specific ?
PktHeaderLen = 10 // = unsafe.Sizeof(PktHeader{}), but latter gives typed constant (uintptr) PktHeaderLen = 10 // = unsafe.Sizeof(PktHeader{}), but latter gives typed constant (uintptr)
// packets larger than PktMaxSize are not allowed. // packets larger than PktMaxSize are not allowed.
...@@ -99,6 +100,7 @@ const ( ...@@ -99,6 +100,7 @@ const (
INVALID_OID zodb.Oid = 1<<64 - 1 INVALID_OID zodb.Oid = 1<<64 - 1
) )
// XXX encN-specific ?
// PktHeader represents header of a raw packet. // PktHeader represents header of a raw packet.
// //
// A packet contains connection ID and message. // A packet contains connection ID and message.
...@@ -110,31 +112,75 @@ type PktHeader struct { ...@@ -110,31 +112,75 @@ type PktHeader struct {
MsgLen packed.BE32 // payload message length (excluding packet header) MsgLen packed.BE32 // payload message length (excluding packet header)
} }
// Msg is the interface implemented by all NEO messages. // Msg is the interface representing a NEO message.
type Msg interface { type Msg interface {
// marshal/unmarshal into/from wire format: // marshal/unmarshal into/from wire format:
// NEOMsgCode returns message code needed to be used for particular message type // neoMsgCode returns message code needed to be used for particular message type
// on the wire. // on the wire.
NEOMsgCode() uint16 neoMsgCode() uint16
// NEOMsgEncodedLen returns how much space is needed to encode current message payload. // for encoding E:
NEOMsgEncodedLen() int //
// - neoMsgEncodedLen<E> returns how much space is needed to encode current message payload via E encoding.
// NEOMsgEncode encodes current message state into buf. //
// - neoMsgEncode<E> encodes current message state into buf via E encoding.
// //
// len(buf) must be >= neoMsgEncodedLen(). // len(buf) must be >= neoMsgEncodedLen<E>().
NEOMsgEncode(buf []byte) //
// - neoMsgDecode<E> decodes data via E encoding into message in-place.
// N encoding (original struct-based encoding)
neoMsgEncodedLenN() int
neoMsgEncodeN(buf []byte)
neoMsgDecodeN(data []byte) (nread int, err error)
// NEOMsgDecode decodes data into message in-place. // M encoding (via MessagePack)
NEOMsgDecode(data []byte) (nread int, err error) neoMsgEncodedLenM() int
neoMsgEncodeM(buf []byte)
neoMsgDecodeM(data []byte) (nread int, err error)
}
// Encoding represents messages encoding.
type Encoding byte
// XXX drop "NEO" prefix?
// NEOMsgEncodedLen returns how much space is needed to encode msg payload via encoding e.
func (e Encoding) NEOMsgEncodedLen(msg Msg) int {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgEncodedLenN()
case 'M': return msg.neoMsgEncodedLenM()
}
} }
// NEOMsgEncode encodes msg state into buf via encoding e.
//
// len(buf) must be >= e.NEOMsgEncodedLen(m).
func (e Encoding) NEOMsgEncode(msg Msg, buf []byte) {
switch e {
default: panic("bug")
case 'N': msg.neoMsgEncodeN(buf)
case 'M': msg.neoMsgEncodeM(buf)
}
}
// NEOMsgDecode decodes data via encoding e into msg in-place.
func (e Encoding) NEOMsgDecode(msg Msg, data []byte) (nread int, err error) {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgDecodeN(data)
case 'M': return msg.neoMsgDecodeM(data)
}
}
// ErrDecodeOverflow is the error returned by neoMsgDecode when decoding hits buffer overflow // ErrDecodeOverflow is the error returned by neoMsgDecode when decoding hits buffer overflow
var ErrDecodeOverflow = errors.New("decode: buffer overflow") var ErrDecodeOverflow = errors.New("decode: buffer overflow")
// ---- messages ---- // ---- messages ----
//neo:proto enum
type ErrorCode uint32 type ErrorCode uint32
const ( const (
ACK ErrorCode = iota ACK ErrorCode = iota
...@@ -155,6 +201,7 @@ const ( ...@@ -155,6 +201,7 @@ const (
// XXX move this to neo.clusterState wrapping proto.ClusterState? // XXX move this to neo.clusterState wrapping proto.ClusterState?
//trace:event traceClusterStateChanged(cs *ClusterState) //trace:event traceClusterStateChanged(cs *ClusterState)
//neo:proto enum
type ClusterState int8 type ClusterState int8
const ( const (
// The cluster is initially in the RECOVERING state, and it goes back to // The cluster is initially in the RECOVERING state, and it goes back to
...@@ -188,6 +235,7 @@ const ( ...@@ -188,6 +235,7 @@ const (
STOPPING_BACKUP STOPPING_BACKUP
) )
//neo:proto enum
type NodeType int8 type NodeType int8
const ( const (
MASTER NodeType = iota MASTER NodeType = iota
...@@ -196,6 +244,7 @@ const ( ...@@ -196,6 +244,7 @@ const (
ADMIN ADMIN
) )
//neo:proto enum
type NodeState int8 type NodeState int8
const ( const (
UNKNOWN NodeState = iota //short: U // XXX tag prefix name ? UNKNOWN NodeState = iota //short: U // XXX tag prefix name ?
...@@ -204,6 +253,7 @@ const ( ...@@ -204,6 +253,7 @@ const (
PENDING //short: P PENDING //short: P
) )
//neo:proto enum
type CellState int8 type CellState int8
const ( const (
// Write-only cell. Last transactions are missing because storage is/was down // Write-only cell. Last transactions are missing because storage is/was down
...@@ -255,7 +305,7 @@ type Address struct { ...@@ -255,7 +305,7 @@ type Address struct {
} }
// NOTE if Host == "" -> Port not added to wire (see py.PAddress): // NOTE if Host == "" -> Port not added to wire (see py.PAddress):
func (a *Address) neoEncodedLen() int { func (a *Address) neoEncodedLenN() int {
l := string_neoEncodedLen(a.Host) l := string_neoEncodedLen(a.Host)
if a.Host != "" { if a.Host != "" {
l += 2 l += 2
...@@ -263,7 +313,7 @@ func (a *Address) neoEncodedLen() int { ...@@ -263,7 +313,7 @@ func (a *Address) neoEncodedLen() int {
return l return l
} }
func (a *Address) neoEncode(b []byte) int { func (a *Address) neoEncodeN(b []byte) int {
n := string_neoEncode(a.Host, b[0:]) n := string_neoEncode(a.Host, b[0:])
if a.Host != "" { if a.Host != "" {
binary.BigEndian.PutUint16(b[n:], a.Port) binary.BigEndian.PutUint16(b[n:], a.Port)
...@@ -272,7 +322,7 @@ func (a *Address) neoEncode(b []byte) int { ...@@ -272,7 +322,7 @@ func (a *Address) neoEncode(b []byte) int {
return n return n
} }
func (a *Address) neoDecode(b []byte) (uint64, bool) { func (a *Address) neoDecodeN(b []byte) (uint64, bool) {
n, ok := string_neoDecode(&a.Host, b) n, ok := string_neoDecode(&a.Host, b)
if !ok { if !ok {
return 0, false return 0, false
...@@ -295,17 +345,17 @@ type Checksum [20]byte ...@@ -295,17 +345,17 @@ type Checksum [20]byte
// PTid is Partition Table identifier. // PTid is Partition Table identifier.
// //
// Zero value means "invalid id" (<-> None in py.PPTID) // Zero value means "invalid id" (<-> None in py.PPTID) XXX = nil in msgpack
type PTid uint64 type PTid uint64
// IdTime represents time of identification. // IdTime represents time of identification.
type IdTime float64 type IdTime float64
func (t IdTime) neoEncodedLen() int { func (t IdTime) neoEncodedLenN() int {
return 8 return 8
} }
func (t IdTime) neoEncode(b []byte) int { func (t IdTime) neoEncodeN(b []byte) int {
// use -inf as value for no data (NaN != NaN -> hard to use NaN in tests) // use -inf as value for no data (NaN != NaN -> hard to use NaN in tests)
// NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer // NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer
tt := float64(t) tt := float64(t)
...@@ -316,7 +366,7 @@ func (t IdTime) neoEncode(b []byte) int { ...@@ -316,7 +366,7 @@ func (t IdTime) neoEncode(b []byte) int {
return 8 return 8
} }
func (t *IdTime) neoDecode(data []byte) (uint64, bool) { func (t *IdTime) neoDecodeN(data []byte) (uint64, bool) {
if len(data) < 8 { if len(data) < 8 {
return 0, false return 0, false
} }
...@@ -438,8 +488,8 @@ type Recovery struct { ...@@ -438,8 +488,8 @@ type Recovery struct {
type AnswerRecovery struct { type AnswerRecovery struct {
PTid PTid
BackupTid zodb.Tid BackupTid zodb.Tid // XXX nil <-> 0
TruncateTid zodb.Tid TruncateTid zodb.Tid // XXX nil <-> 0
} }
// Ask the last OID/TID so that a master can initialize its TransactionManager. // Ask the last OID/TID so that a master can initialize its TransactionManager.
...@@ -1199,13 +1249,13 @@ type FlushLog struct {} ...@@ -1199,13 +1249,13 @@ type FlushLog struct {}
// ---- runtime support for protogen and custom codecs ---- // ---- runtime support for protogen and custom codecs ----
// customCodec is the interface that is implemented by types with custom encodings. // customCodecN is the interface that is implemented by types with custom N encodings.
// //
// its semantic is very similar to Msg. // its semantic is very similar to Msg.
type customCodec interface { type customCodecN interface {
neoEncodedLen() int neoEncodedLenN() int
neoEncode(buf []byte) (nwrote int) neoEncodeN(buf []byte) (nwrote int)
neoDecode(data []byte) (nread uint64, ok bool) // XXX uint64 or int here? neoDecodeN(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
} }
func byte2bool(b byte) bool { func byte2bool(b byte) bool {
......
...@@ -79,31 +79,32 @@ func TestPktHeader(t *testing.T) { ...@@ -79,31 +79,32 @@ func TestPktHeader(t *testing.T) {
} }
// test marshalling for one message type // test marshalling for one message type
func testMsgMarshal(t *testing.T, msg Msg, encoded string) { func testMsgMarshal(t *testing.T, enc Encoding, msg Msg, encoded string) {
typ := reflect.TypeOf(msg).Elem() // type of *msg typ := reflect.TypeOf(msg).Elem() // type of *msg
msg2 := reflect.New(typ).Interface().(Msg) msg2 := reflect.New(typ).Interface().(Msg)
defer func() { defer func() {
if e := recover(); e != nil { if e := recover(); e != nil {
t.Errorf("%v: panic ↓↓↓:", typ) t.Errorf("%c/%v: panic ↓↓↓:", enc, typ)
panic(e) // to show traceback panic(e) // to show traceback
} }
}() }()
// msg.encode() == expected // msg.encode() == expected
msgCode := msg.NEOMsgCode() msgCode := msg.neoMsgCode()
n := msg.NEOMsgEncodedLen() n := enc.NEOMsgEncodedLen(msg)
msgType := MsgType(msgCode) msgType := MsgType(msgCode)
if msgType != typ { if msgType != typ {
t.Errorf("%v: msgCode = %v which corresponds to %v", typ, msgCode, msgType) t.Errorf("%c/%v: msgCode = %v which corresponds to %v", enc, typ, msgCode, msgType)
} }
if n != len(encoded) { if n != len(encoded) {
t.Errorf("%v: encodedLen = %v ; want %v", typ, n, len(encoded)) t.Errorf("%c/%v: encodedLen = %v ; want %v", enc, typ, n, len(encoded))
} }
buf := make([]byte, n) buf := make([]byte, n)
msg.NEOMsgEncode(buf) enc.NEOMsgEncode(msg, buf)
if string(buf) != encoded { if string(buf) != encoded {
t.Errorf("%v: encode result unexpected:", typ) t.Errorf("%c/%v: encode result unexpected:", enc, typ)
t.Errorf("\thave: %s", hexpkg.EncodeToString(buf)) t.Errorf("\thave: %s", hexpkg.EncodeToString(buf))
t.Errorf("\twant: %s", hexpkg.EncodeToString([]byte(encoded))) t.Errorf("\twant: %s", hexpkg.EncodeToString([]byte(encoded)))
} }
...@@ -112,7 +113,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -112,7 +113,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
for l := len(buf) - 1; l >= 0; l-- { for l := len(buf) - 1; l >= 0; l-- {
func() { func() {
defer func() { defer func() {
subj := fmt.Sprintf("%v: encode(buf[:encodedLen-%v])", typ, len(encoded)-l) subj := fmt.Sprintf("%c/%v: encode(buf[:encodedLen-%v])", enc, typ, len(encoded)-l)
e := recover() e := recover()
if e == nil { if e == nil {
t.Errorf("%s did not panic", subj) t.Errorf("%s did not panic", subj)
...@@ -131,29 +132,29 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -131,29 +132,29 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
} }
}() }()
msg.NEOMsgEncode(buf[:l]) enc.NEOMsgEncode(msg, buf[:l])
}() }()
} }
// msg.decode() == expected // msg.decode() == expected
data := []byte(encoded + "noise") data := []byte(encoded + "noise")
n, err := msg2.NEOMsgDecode(data) n, err := enc.NEOMsgDecode(msg2, data)
if err != nil { if err != nil {
t.Errorf("%v: decode error %v", typ, err) t.Errorf("%c/%v: decode error %v", enc, typ, err)
} }
if n != len(encoded) { if n != len(encoded) {
t.Errorf("%v: nread = %v ; want %v", typ, n, len(encoded)) t.Errorf("%c/%v: nread = %v ; want %v", enc, typ, n, len(encoded))
} }
if !reflect.DeepEqual(msg2, msg) { if !reflect.DeepEqual(msg2, msg) {
t.Errorf("%v: decode result unexpected: %v ; want %v", typ, msg2, msg) t.Errorf("%c/%v: decode result unexpected: %v ; want %v", enc, typ, msg2, msg)
} }
// decode must detect buffer overflow // decode must detect buffer overflow
for l := len(encoded) - 1; l >= 0; l-- { for l := len(encoded) - 1; l >= 0; l-- {
n, err = msg2.NEOMsgDecode(data[:l]) n, err = enc.NEOMsgDecode(msg2, data[:l])
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%v: decode overflow not detected on [:%v]", typ, l) t.Errorf("%c/%v: decode overflow not detected on [:%v]", enc, typ, l)
} }
} }
...@@ -163,13 +164,20 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) { ...@@ -163,13 +164,20 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
func TestMsgMarshal(t *testing.T) { func TestMsgMarshal(t *testing.T) {
var testv = []struct { var testv = []struct {
msg Msg msg Msg
encoded string // []byte encodedN string // []byte
encodedM string // []byte
}{ }{
// empty // empty
{&Ping{}, ""}, {&Ping{},
"",
"\x90",
},
// uint32, string // uint32(N)/enum(M), string
{&Error{Code: 0x01020304, Message: "hello"}, "\x01\x02\x03\x04\x00\x00\x00\x05hello"}, {&Error{Code: 0x00000045, Message: "hello"},
"\x00\x00\x00\x45\x00\x00\x00\x05hello",
hex("92") + hex("d40045") + "\xc4\x05hello",
},
// Oid, Tid, bool, Checksum, []byte // Oid, Tid, bool, Checksum, []byte
{&StoreObject{ {&StoreObject{
...@@ -185,7 +193,18 @@ func TestMsgMarshal(t *testing.T) { ...@@ -185,7 +193,18 @@ func TestMsgMarshal(t *testing.T) {
hex("01020304050607080a0b0c0d0e0f010200") + hex("01020304050607080a0b0c0d0e0f010200") +
hex("0102030405060708090a0b0c0d0e0f1011121314") + hex("0102030405060708090a0b0c0d0e0f1011121314") +
hex("0000000b") + "hello world" + hex("0000000b") + "hello world" +
hex("0a0b0c0d0e0f01030a0b0c0d0e0f0104")}, hex("0a0b0c0d0e0f01030a0b0c0d0e0f0104"),
// M
hex("97") +
hex("c408") + hex("0102030405060708") +
hex("c408") + hex("0a0b0c0d0e0f0102") +
hex("c2") +
hex("c414") + hex("0102030405060708090a0b0c0d0e0f1011121314") +
hex("c40b") + "hello world" +
hex("c408") + hex("0a0b0c0d0e0f0103") +
hex("c408") + hex("0a0b0c0d0e0f0104"),
},
// PTid, [] (of [] of {UUID, CellState}) // PTid, [] (of [] of {UUID, CellState})
{&AnswerPartitionTable{ {&AnswerPartitionTable{
...@@ -198,12 +217,22 @@ func TestMsgMarshal(t *testing.T) { ...@@ -198,12 +217,22 @@ func TestMsgMarshal(t *testing.T) {
}, },
}, },
// N
hex("0102030405060708") + hex("0102030405060708") +
hex("00000022") + hex("00000022") +
hex("00000003") + hex("00000003") +
hex("000000020000000b010000001100") + hex("000000020000000b010000001100") +
hex("000000010000000b02") + hex("000000010000000b02") +
hex("000000030000000b030000000f040000001701"), hex("000000030000000b030000000f040000001701"),
// M
hex("93") +
hex("cf0102030405060708") +
hex("22") +
hex("93") +
hex("91"+"92"+"920bd40401"+"9211d40400") +
hex("91"+"91"+"920bd40402") +
hex("91"+"93"+"920bd40403"+"920fd40404"+"9217d40401"),
}, },
// map[Oid]struct {Tid,Tid,bool} // map[Oid]struct {Tid,Tid,bool}
...@@ -219,11 +248,20 @@ func TestMsgMarshal(t *testing.T) { ...@@ -219,11 +248,20 @@ func TestMsgMarshal(t *testing.T) {
5: {4, 3, true}, 5: {4, 3, true},
}}, }},
// N
u32(4) + u32(4) +
u64(1) + u64(1) + u64(0) + hex("00") + u64(1) + u64(1) + u64(0) + hex("00") +
u64(2) + u64(7) + u64(1) + hex("01") + u64(2) + u64(7) + u64(1) + hex("01") +
u64(5) + u64(4) + u64(3) + hex("01") + u64(5) + u64(4) + u64(3) + hex("01") +
u64(8) + u64(7) + u64(1) + hex("00"), u64(8) + u64(7) + u64(1) + hex("00"),
// M
hex("91") +
hex("84") +
hex("c408")+u64(1) + hex("93") + hex("c408")+u64(1) + hex("c408")+u64(0) + hex("c2") +
hex("c408")+u64(2) + hex("93") + hex("c408")+u64(7) + hex("c408")+u64(1) + hex("c3") +
hex("c408")+u64(5) + hex("93") + hex("c408")+u64(4) + hex("c408")+u64(3) + hex("c3") +
hex("c408")+u64(8) + hex("93") + hex("c408")+u64(7) + hex("c408")+u64(1) + hex("c2"),
}, },
// map[uint32]UUID + trailing ... // map[uint32]UUID + trailing ...
...@@ -238,41 +276,86 @@ func TestMsgMarshal(t *testing.T) { ...@@ -238,41 +276,86 @@ func TestMsgMarshal(t *testing.T) {
MaxTID: 128, MaxTID: 128,
}, },
// N
u32(4) + u32(4) +
u32(1) + u32(7) + u32(1) + u32(7) +
u32(2) + u32(9) + u32(2) + u32(9) +
u32(4) + u32(17) + u32(4) + u32(17) +
u32(7) + u32(3) + u32(7) + u32(3) +
u64(23) + u64(128), u64(23) + u64(128),
// M
hex("93") +
hex("84") +
hex("01" + "07") +
hex("02" + "09") +
hex("04" + "11") +
hex("07" + "03") +
hex("c408") + u64(23) +
hex("c408") + u64(128),
}, },
// uint32, []uint32 // uint32, []uint32
{&PartitionCorrupted{7, []NodeUUID{1, 3, 9, 4}}, {&PartitionCorrupted{7, []NodeUUID{1, 3, 9, 4}},
// N
u32(7) + u32(4) + u32(1) + u32(3) + u32(9) + u32(4), u32(7) + u32(4) + u32(1) + u32(3) + u32(9) + u32(4),
// M
hex("92") +
hex("07") +
hex("94") +
hex("01030904"),
}, },
// uint32, Address, string, IdTime // uint32, Address, string, IdTime
{&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678, []string{"room1", "rack234"}, []uint32{3,4,5} }, {&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678, []string{"room1", "rack234"}, []uint32{3,4,5} },
// N
u8(2) + u32(17) + u32(9) + u8(2) + u32(17) + u32(9) +
"localhost" + u16(7777) + "localhost" + u16(7777) +
u32(6) + "myname" + u32(6) + "myname" +
hex("3fbf9add1091c895") + hex("3fbf9add1091c895") +
u32(2) + u32(5)+"room1" + u32(7)+"rack234" + u32(2) + u32(5)+"room1" + u32(7)+"rack234" +
u32(3) + u32(3)+u32(4)+u32(5), u32(3) + u32(3)+u32(4)+u32(5),
// M
hex("97") +
hex("d40202") +
hex("11") +
hex("92") + hex("c409")+"localhost" + hex("cd")+u16(7777) +
hex("c406")+"myname" +
hex("cb" + "3fbf9add1091c895") +
hex("92") + hex("c405")+"room1" + hex("c407")+"rack234" +
hex("93") + hex("030405"),
}, },
// IdTime, empty Address, int32 // IdTime, empty Address, int32
{&NotifyNodeInformation{1504466245.926185, []NodeInfo{ {&NotifyNodeInformation{1504466245.926185, []NodeInfo{
{CLIENT, Address{}, UUID(CLIENT, 1), RUNNING, 1504466245.925599}}}, {CLIENT, Address{}, UUID(CLIENT, 1), RUNNING, 1504466245.925599}}},
// N
hex("41d66b15517b469d") + u32(1) + hex("41d66b15517b469d") + u32(1) +
u8(2) + u32(0) /* <- ø Address */ + hex("e0000001") + u8(2) + u8(2) + u32(0) /* <- ø Address */ + hex("e0000001") + u8(2) +
hex("41d66b15517b3d04"), hex("41d66b15517b3d04"),
// M
hex("92") +
hex("cb" + "41d66b15517b469d") +
hex("91") +
hex("95") +
hex("d40202") +
hex("92" + "c400"+"" + "00") +
hex("d2" + "e0000001") +
hex("d40302") +
hex("cb" + "41d66b15517b3d04"),
}, },
// empty IdTime // empty IdTime
{&NotifyNodeInformation{IdTimeNone, []NodeInfo{}}, hex("ffffffffffffffff") + hex("00000000")}, {&NotifyNodeInformation{IdTimeNone, []NodeInfo{}},
// N
hex("ffffffffffffffff") + hex("00000000"),
// M
hex("92") +
hex("cb" + "fff0000000000000") + // XXX nan/-inf not handled yet
hex("90"),
},
// TODO we need tests for: // TODO we need tests for:
// []varsize + trailing // []varsize + trailing
...@@ -280,7 +363,8 @@ func TestMsgMarshal(t *testing.T) { ...@@ -280,7 +363,8 @@ func TestMsgMarshal(t *testing.T) {
} }
for _, tt := range testv { for _, tt := range testv {
testMsgMarshal(t, tt.msg, tt.encoded) testMsgMarshal(t, 'N', tt.msg, tt.encodedN)
testMsgMarshal(t, 'M', tt.msg, tt.encodedM)
} }
} }
...@@ -288,18 +372,23 @@ func TestMsgMarshal(t *testing.T) { ...@@ -288,18 +372,23 @@ func TestMsgMarshal(t *testing.T) {
// this way we additionally lightly check encode / decode overflow behaviour for all types. // this way we additionally lightly check encode / decode overflow behaviour for all types.
func TestMsgMarshalAllOverflowLightly(t *testing.T) { func TestMsgMarshalAllOverflowLightly(t *testing.T) {
for _, typ := range msgTypeRegistry { for _, typ := range msgTypeRegistry {
for _, enc := range []Encoding{'N', 'M'} {
// zero-value for a type // zero-value for a type
msg := reflect.New(typ).Interface().(Msg) msg := reflect.New(typ).Interface().(Msg)
l := msg.NEOMsgEncodedLen() l := enc.NEOMsgEncodedLen(msg)
zerol := make([]byte, l) zerol := make([]byte, l)
if enc != 'N' { // M-encoding of zero-value is not all zeros
enc.NEOMsgEncode(msg, zerol)
}
// decoding will turn nil slice & map into empty allocated ones. // decoding will turn nil slice & map into empty allocated ones.
// we need it so that reflect.DeepEqual works for msg encode/decode comparison // we need it so that reflect.DeepEqual works for msg encode/decode comparison
n, err := msg.NEOMsgDecode(zerol) n, err := enc.NEOMsgDecode(msg, zerol)
if !(n == l && err == nil) { if !(n == l && err == nil) {
t.Errorf("%v: zero-decode unexpected: %v, %v ; want %v, nil", typ, n, err, l) t.Errorf("%c/%v: zero-decode unexpected: %v, %v ; want %v, nil", enc, typ, n, err, l)
} }
testMsgMarshal(t, msg, string(zerol)) testMsgMarshal(t, enc, msg, string(zerol))
}
} }
} }
...@@ -316,6 +405,8 @@ func TestMsgDecodeLenOverflow(t *testing.T) { ...@@ -316,6 +405,8 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
{&AnswerLockedTransactions{}, u32(0x10000000)}, {&AnswerLockedTransactions{}, u32(0x10000000)},
} }
enc := Encoding('N') // XXX hardcoded XXX + M-variants with big len?
for _, tt := range testv { for _, tt := range testv {
data := []byte(tt.data) data := []byte(tt.data)
func() { func() {
...@@ -325,7 +416,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) { ...@@ -325,7 +416,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
} }
}() }()
n, err := tt.msg.NEOMsgDecode(data) n, err := enc.NEOMsgDecode(tt.msg, data)
if !(n == 0 && err == ErrDecodeOverflow) { if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data, t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data,
n, err, 0, ErrDecodeOverflow) n, err, 0, ErrDecodeOverflow)
......
This diff is collapsed.
This diff is collapsed.
...@@ -387,6 +387,8 @@ func Verify(t *testing.T, f func(*tEnv)) { ...@@ -387,6 +387,8 @@ func Verify(t *testing.T, f func(*tEnv)) {
// TODO verify M=(go|py) x S=(go|py) x ... // TODO verify M=(go|py) x S=(go|py) x ...
// for now we only verify for all combinations of network // for now we only verify for all combinations of network
// TODO verify enc=(M|N)
// for all networks // for all networks
for _, network := range []string{"pipenet", "lonet"} { for _, network := range []string{"pipenet", "lonet"} {
opt := tClusterOptions{ opt := tClusterOptions{
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment