Commit c0d54d50 authored by Kirill Smelkov's avatar Kirill Smelkov

go/neo/proto: Introduce Encoding

Encoding specifies a way to encode/decode NEO messages and packets.
Current way of how messages were encoded is called to be 'N' encoding.

This patch:

- adds proto.Encoding type
- changes MsgEncode and MsgDecode to be methods of Encoding
- renames thigs that are specific to 'N' encoding to have 'N' suffix
- changes tests to run a testcase agains vector of provided encodings.
  That vector is currently only ['N'].
parent 39545b9c
......@@ -122,7 +122,8 @@ import (
//
// It is safe to use NodeLink from multiple goroutines simultaneously.
type NodeLink struct {
peerLink net.Conn // raw conn to peer
peerLink net.Conn // raw conn to peer
enc proto.Encoding // protocol encoding in use ('N')
connMu sync.Mutex
connTab map[uint32]*Conn // connId -> Conn associated with connId
......@@ -151,7 +152,7 @@ type NodeLink struct {
axclosed atomic32 // whether CloseAccept was called
closed atomic32 // whether Close was called
rxbuf rbuf.RingBuf // buffer for reading from peerLink
rxbufN rbuf.RingBuf // buffer for reading from peerLink (N encoding)
// scheduling optimization: whenever serveRecv sends to Conn.rxq
// receiving side must ack here to receive G handoff.
......@@ -246,6 +247,8 @@ const (
// newNodeLink makes a new NodeLink from already established net.Conn .
//
// On the wire messages will be encoded according to enc.
//
// Role specifies how to treat our role on the link - either as client or
// server. The difference in between client and server roles is in:
//
......@@ -258,7 +261,7 @@ const (
//
// Though it is possible to wrap just-established raw connection into NodeLink,
// users should always use Handshake which performs protocol handshaking first.
func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink {
func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole) *NodeLink {
var nextConnId uint32
switch role &^ linkFlagsMask {
case _LinkServer:
......@@ -271,6 +274,7 @@ func newNodeLink(conn net.Conn, role _LinkRole) *NodeLink {
nl := &NodeLink{
peerLink: conn,
enc: enc,
connTab: map[uint32]*Conn{},
nextConnId: nextConnId,
acceptq: make(chan *Conn), // XXX +buf ?
......@@ -792,7 +796,7 @@ func (nl *NodeLink) serveRecv() {
// pkt.ConnId -> Conn
var connId uint32
if err == nil {
connId, _, _, err = pktDecodeHead(pkt)
connId, _, _, err = pktDecodeHead(nl.enc, pkt)
}
// on IO error framing over peerLink becomes broken
......@@ -1040,7 +1044,7 @@ func (c *Conn) sendPkt(pkt *pktBuf) error {
func (c *Conn) sendPkt2(pkt *pktBuf) error {
// connId must be set to one associated with this connection
connID, _, _, err := pktDecodeHead(pkt)
connID, _, _, err := pktDecodeHead(c.link.enc, pkt)
if err != nil {
panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err))
}
......@@ -1129,7 +1133,7 @@ func (nl *NodeLink) serveSend() {
// sendPktDirect sends raw packet with appropriate connection ID directly via link.
func (c *Conn) sendPktDirect(pkt *pktBuf) error {
// connId must be set to one associated with this connection
connID, _, _, err := pktDecodeHead(pkt)
connID, _, _, err := pktDecodeHead(c.link.enc, pkt)
if err != nil {
panic(fmt.Sprintf("Conn.sendPkt: bad packet: %s", err))
}
......@@ -1166,7 +1170,7 @@ const dumpio = false
func (nl *NodeLink) sendPkt(pkt *pktBuf) error {
if dumpio {
// XXX -> log
fmt.Printf("%v > %v: %v\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pkt)
fmt.Printf("%s > %s: %s\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pktString(nl.enc, pkt))
//defer fmt.Printf("\t-> sendPkt err: %v\n", err)
}
......@@ -1183,8 +1187,29 @@ var ErrPktTooBig = errors.New("packet too big")
// rx error, if any, is returned as is and is analyzed in serveRecv
//
// XXX dup in ZEO.
func (nl *NodeLink) recvPkt() (*pktBuf, error) {
// FIXME if rxbuf is non-empty - first look there for header and then if
func (nl *NodeLink) recvPkt() (pkt *pktBuf, err error) {
switch nl.enc {
case 'N': pkt, err = nl.recvPktN()
default: panic("bug")
}
if dumpio {
// XXX -> log
s := fmt.Sprintf("%s < %s: ", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr())
if err != nil {
s += err.Error()
} else {
s += pktString(nl.enc, pkt)
}
fmt.Println(s)
}
return pkt, err
}
func (nl *NodeLink) recvPktN() (*pktBuf, error) {
// FIXME if rxbufN is non-empty - first look there for header and then if
// we know size -> allocate pkt with that size.
pkt := pktAlloc(4096)
// len=4K but cap can be more since pkt is from pool - use all space to buffer reads
......@@ -1194,35 +1219,35 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n := 0 // number of pkt bytes obtained so far
// next packet could be already prefetched in part by previous read
if nl.rxbuf.Len() > 0 {
δn, _ := nl.rxbuf.Read(data[:proto.PktHeaderLen])
if nl.rxbufN.Len() > 0 {
δn, _ := nl.rxbufN.Read(data[:proto.PktHeaderLenN])
n += δn
}
// first read to read pkt header and hopefully rest of packet in 1 syscall
if n < proto.PktHeaderLen {
δn, err := io.ReadAtLeast(nl.peerLink, data[n:], proto.PktHeaderLen - n)
if n < proto.PktHeaderLenN {
δn, err := io.ReadAtLeast(nl.peerLink, data[n:], proto.PktHeaderLenN - n)
if err != nil {
return nil, err
}
n += δn
}
pkth := pkt.Header()
pkth := pkt.HeaderN()
msgLen := packed.Ntoh32(pkth.MsgLen)
if msgLen > proto.PktMaxSize - proto.PktHeaderLen {
if msgLen > proto.PktMaxSize - proto.PktHeaderLenN {
return nil, ErrPktTooBig
}
pktLen := int(proto.PktHeaderLen + msgLen) // whole packet length
pktLen := int(proto.PktHeaderLenN + msgLen) // whole packet length
// resize data if we don't have enough room in it
data = xbytes.Resize(data, pktLen)
data = data[:cap(data)]
// we might have more data already prefetched in rxbuf
if nl.rxbuf.Len() > 0 {
δn, _ := nl.rxbuf.Read(data[n:pktLen])
// we might have more data already prefetched in rxbufN
if nl.rxbufN.Len() > 0 {
δn, _ := nl.rxbufN.Read(data[n:pktLen])
n += δn
}
......@@ -1235,20 +1260,15 @@ func (nl *NodeLink) recvPkt() (*pktBuf, error) {
n += δn
}
// put overread data into rxbuf for next reader
// put overread data into rxbufN for next reader
if n > pktLen {
nl.rxbuf.Write(data[pktLen:n])
nl.rxbufN.Write(data[pktLen:n])
}
// fixup data/pkt
data = data[:pktLen]
pkt.data = data
if dumpio {
// XXX -> log
fmt.Printf("%v < %v: %v\n", nl.peerLink.LocalAddr(), nl.peerLink.RemoteAddr(), pkt)
}
return pkt, nil
}
......@@ -1320,29 +1340,51 @@ func (c *Conn) err(op string, e error) error {
// ---- exchange of messages ----
// pktEncode allocates pktBuf and encodes msg into it.
func pktEncode(connId uint32, msg proto.Msg) *pktBuf {
l := proto.MsgEncodedLen(msg)
buf := pktAlloc(proto.PktHeaderLen + l)
func pktEncode(e proto.Encoding, connId uint32, msg proto.Msg) *pktBuf {
switch e {
case 'N': return pktEncodeN(connId, msg)
default: panic("bug")
}
}
h := buf.Header()
h.ConnId = packed.Hton32(connId)
// pktDecodeHead decodes header of a packet.
func pktDecodeHead(e proto.Encoding, pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
switch e {
case 'N': connID, msgCode, payload, err = pktDecodeHeadN(pkt)
default: panic("bug")
}
if err != nil {
err = fmt.Errorf("%c: decode header: %s", e, err)
}
return connID, msgCode, payload, err
}
func pktEncodeN(connId uint32, msg proto.Msg) *pktBuf {
const enc = proto.Encoding('N')
l := enc.MsgEncodedLen(msg)
buf := pktAlloc(proto.PktHeaderLenN + l)
h := buf.HeaderN()
h.ConnId = packed.Hton32(connId)
h.MsgCode = packed.Hton16(proto.MsgCode(msg))
h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again
h.MsgLen = packed.Hton32(uint32(l)) // XXX casting: think again
proto.MsgEncode(msg, buf.Payload())
enc.MsgEncode(msg, buf.PayloadN())
return buf
}
// pktDecodeHead decodes header of a packet.
func pktDecodeHead(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
if len(pkt.data) < proto.PktHeaderLen {
func pktDecodeHeadN(pkt *pktBuf) (connID uint32, msgCode uint16, payload []byte, err error) {
if len(pkt.data) < proto.PktHeaderLenN {
return 0, 0, nil, fmt.Errorf("packet too short")
}
pkth := pkt.Header()
pkth := pkt.HeaderN()
connID = packed.Ntoh32(pkth.ConnId)
msgCode = packed.Ntoh16(pkth.MsgCode)
msgLen := packed.Ntoh32(pkth.MsgLen)
payload = pkt.Payload()
payload = pkt.PayloadN()
if len(payload) != int(msgLen) {
return 0, 0, nil, fmt.Errorf("len(payload) != msgLen")
}
......@@ -1359,7 +1401,7 @@ func (c *Conn) Recv() (proto.Msg, error) {
defer pkt.Free()
// decode packet
_, msgCode, payload, err := pktDecodeHead(pkt)
_, msgCode, payload, err := pktDecodeHead(c.link.enc, pkt)
if err != nil {
return nil, err
}
......@@ -1376,7 +1418,7 @@ func (c *Conn) Recv() (proto.Msg, error) {
// msg := reflect.NewAt(msgType, bufAlloc(msgType.Size())
_, err = proto.MsgDecode(msg, payload)
_, err = c.link.enc.MsgDecode(msg, payload)
if err != nil {
return nil, c.err("decode", err) // XXX "decode:" is already in ErrDecodeOverflow
}
......@@ -1390,14 +1432,14 @@ func (c *Conn) Recv() (proto.Msg, error) {
//
// it is ok to call sendMsg in parallel with serveSend. XXX link to sendPktDirect for rationale?
func (link *NodeLink) sendMsg(connId uint32, msg proto.Msg) error {
buf := pktEncode(connId, msg)
buf := pktEncode(link.enc, connId, msg)
return link.sendPkt(buf) // XXX more context in err? (msg type)
// FIXME ^^^ shutdown whole link on error
}
// Send sends message over the connection.
func (c *Conn) Send(msg proto.Msg) error {
buf := pktEncode(c.connId, msg)
buf := pktEncode(c.link.enc, c.connId, msg)
return c.sendPkt(buf) // XXX more context in err? (msg type)
}
......@@ -1421,14 +1463,14 @@ func (c *Conn) Expect(msgv ...proto.Msg) (which int, err error) {
defer pkt.Free()
// decode packet
_, msgCode, payload, err := pktDecodeHead(pkt)
_, msgCode, payload, err := pktDecodeHead(c.link.enc, pkt)
if err != nil {
return -1, err
}
for i, msg := range msgv {
if proto.MsgCode(msg) == msgCode {
_, err := proto.MsgDecode(msg, payload)
_, err := c.link.enc.MsgDecode(msg, payload)
if err != nil {
return -1, c.err("decode", err)
}
......
......@@ -22,6 +22,7 @@ package neonet
import (
"bytes"
"context"
"fmt"
"io"
"net"
"reflect"
......@@ -45,16 +46,25 @@ import (
// T is neonet testing environment.
type T struct {
*testing.T
enc proto.Encoding // encoding to use for messages exchange
}
// Verify tests f for all possible environments.
func Verify(t *testing.T, f func(*T)) {
f(&T{t})
// for each encoding
for _, enc := range []proto.Encoding{'N'} {
t.Run(fmt.Sprintf("enc=%c", enc), func(t *testing.T) {
f(&T{t, enc})
})
}
}
// bin returns payload for raw binary data as it would-be encoded in t.
// bin returns payload for raw binary data as it would-be encoded by t.enc .
func (t *T) bin(data string) []byte {
return []byte(data)
switch t.enc {
case 'N': return []byte(data)
default: panic("bug")
}
}
......@@ -118,26 +128,32 @@ func xconnError(err error) error {
}
// Prepare pktBuf with content.
func _mkpkt(connid uint32, msgcode uint16, payload []byte) *pktBuf {
pkt := &pktBuf{make([]byte, proto.PktHeaderLen+len(payload))}
h := pkt.Header()
h.ConnId = packed.Hton32(connid)
h.MsgCode = packed.Hton16(msgcode)
h.MsgLen = packed.Hton32(uint32(len(payload)))
copy(pkt.Payload(), payload)
return pkt
func _mkpkt(enc proto.Encoding, connid uint32, msgcode uint16, payload []byte) *pktBuf {
switch enc {
case 'N':
pkt := &pktBuf{make([]byte, proto.PktHeaderLenN+len(payload))}
h := pkt.HeaderN()
h.ConnId = packed.Hton32(connid)
h.MsgCode = packed.Hton16(msgcode)
h.MsgLen = packed.Hton32(uint32(len(payload)))
copy(pkt.PayloadN(), payload)
return pkt
default:
panic("bug")
}
}
func (c *Conn) mkpkt(msgcode uint16, payload []byte) *pktBuf {
// in Conn exchange connid is automatically set by Conn.sendPkt
return _mkpkt(c.connId, msgcode, payload)
return _mkpkt(c.link.enc, c.connId, msgcode, payload)
}
// Verify pktBuf is as expected.
func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byte) {
errv := xerr.Errorv{}
pktConnID, pktMsgCode, pktPayload, err := pktDecodeHead(pkt)
pktConnID, pktMsgCode, pktPayload, err := pktDecodeHead(t.enc, pkt)
exc.Raiseif(err)
// TODO include caller location
......@@ -157,8 +173,8 @@ func (t *T) xverifyPkt(pkt *pktBuf, connid uint32, msgcode uint16, payload []byt
// Verify pktBuf to match expected message.
func (t *T) xverifyPktMsg(pkt *pktBuf, connid uint32, msg proto.Msg) {
data := make([]byte, proto.MsgEncodedLen(msg))
proto.MsgEncode(msg, data)
data := make([]byte, t.enc.MsgEncodedLen(msg))
t.enc.MsgEncode(msg, data)
t.xverifyPkt(pkt, connid, proto.MsgCode(msg), data)
}
......@@ -176,11 +192,11 @@ func tdelay() {
time.Sleep(1 * time.Millisecond)
}
// create NodeLinks connected via net.Pipe
// create NodeLinks connected via net.Pipe; messages are encoded via t.enc.
func (t *T) _nodeLinkPipe(flags1, flags2 _LinkRole) (nl1, nl2 *NodeLink) {
node1, node2 := net.Pipe()
nl1 = newNodeLink(node1, _LinkClient|flags1)
nl2 = newNodeLink(node2, _LinkServer|flags2)
nl1 = newNodeLink(node1, t.enc, _LinkClient|flags1)
nl2 = newNodeLink(node2, t.enc, _LinkServer|flags2)
return nl1, nl2
}
......@@ -289,7 +305,7 @@ func _TestNodeLink(t *T) {
okch := make(chan int, 2)
gox(wg, func(_ context.Context) {
// send ping; wait for pong
pkt := _mkpkt(1, 2, b("ping"))
pkt := _mkpkt(t.enc, 1, 2, b("ping"))
xsendPkt(nl1, pkt)
pkt = xrecvPkt(nl1)
t.xverifyPkt(pkt, 3, 4, b("pong"))
......@@ -299,7 +315,7 @@ func _TestNodeLink(t *T) {
// wait for ping; send pong
pkt = xrecvPkt(nl2)
t.xverifyPkt(pkt, 1, 2, b("ping"))
pkt = _mkpkt(3, 4, b("pong"))
pkt = _mkpkt(t.enc, 3, 4, b("pong"))
xsendPkt(nl2, pkt)
okch <- 2
})
......@@ -614,7 +630,7 @@ func _TestNodeLink(t *T) {
gox(wg, func(_ context.Context) {
pkt := xrecvPkt(c)
_, msgCode, _, err := pktDecodeHead(pkt)
_, msgCode, _, err := pktDecodeHead(t.enc, pkt)
exc.Raiseif(err)
x := replyOrder[msgCode]
......
......@@ -81,7 +81,8 @@ func handshakeClient(ctx context.Context, conn net.Conn, version uint32) (*NodeL
if err != nil {
return nil, err
}
return newNodeLink(conn, _LinkClient), nil
enc := proto.Encoding('N')
return newNodeLink(conn, enc, _LinkClient), nil
}
// handshakeServer implements server-side NEO protocol handshake just after raw
......@@ -96,7 +97,8 @@ func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (*NodeL
if err != nil {
return nil, err
}
return newNodeLink(conn, _LinkServer), nil
enc := proto.Encoding('N')
return newNodeLink(conn, enc, _LinkServer), nil
}
func _handshakeClient(ctx context.Context, conn net.Conn, version uint32) (err error) {
......
......@@ -38,16 +38,16 @@ type pktBuf struct {
data []byte // whole packet data including all headers
}
// Header returns pointer to packet header.
func (pkt *pktBuf) Header() *proto.PktHeader {
// NOTE no need to check len(.data) < PktHeader:
// .data is always allocated with cap >= PktHeaderLen.
return (*proto.PktHeader)(unsafe.Pointer(&pkt.data[0]))
// HeaderN returns pointer to packet header in 'N'-encoding.
func (pkt *pktBuf) HeaderN() *proto.PktHeaderN {
// NOTE no need to check len(.data) < PktHeaderN:
// .data is always allocated with cap >= PktHeaderLenN.
return (*proto.PktHeaderN)(unsafe.Pointer(&pkt.data[0]))
}
// Payload returns []byte representing packet payload.
func (pkt *pktBuf) Payload() []byte {
return pkt.data[proto.PktHeaderLen:]
// PayloadN returns []byte representing packet payload in 'N'-encoding.
func (pkt *pktBuf) PayloadN() []byte {
return pkt.data[proto.PktHeaderLenN:]
}
// ---- pktBuf freelist ----
......@@ -59,11 +59,11 @@ var pktBufPool = sync.Pool{New: func() interface{} {
// pktAlloc allocates pktBuf with len=n.
func pktAlloc(n int) *pktBuf {
// make sure cap >= PktHeaderLen.
// see Header for why
// make sure cap >= PktHeaderLenN.
// see HeaderN for why
l := n
if l < proto.PktHeaderLen {
l = proto.PktHeaderLen
if l < proto.PktHeaderLenN {
l = proto.PktHeaderLenN
}
pkt := pktBufPool.Get().(*pktBuf)
pkt.data = xbytes.Realloc(pkt.data, l)[:n]
......@@ -78,9 +78,9 @@ func (pkt *pktBuf) Free() {
// ---- pktBuf dump ----
// String dumps a packet in human-readable form.
func (pkt *pktBuf) String() string {
connID, msgCode, payload, err := pktDecodeHead(pkt)
// pktString dumps a packet in human-readable form.
func pktString(e proto.Encoding, pkt *pktBuf) string {
connID, msgCode, payload, err := pktDecodeHead(e, pkt)
if err != nil {
return fmt.Sprintf("(%s) % x", err, pkt.data)
}
......@@ -95,7 +95,7 @@ func (pkt *pktBuf) String() string {
// XXX dup wrt Conn.Recv
msg := reflect.New(msgType).Interface().(proto.Msg)
n, err := proto.MsgDecode(msg, payload)
n, err := e.MsgDecode(msg, payload)
if err != nil {
s += fmt.Sprintf(" (%s) %v; [%d]: % x", msgType.Name(), err, len(payload), payload)
} else {
......
......@@ -24,7 +24,7 @@
// ID of subconnection multiplexed on top of the underlying link, carried
// message code and message data.
//
// PktHeader describes packet header structure.
// PktHeaderN describes packet header structure in 'N' encoding.
//
// Messages are represented by corresponding types that all implement Msg interface.
//
......@@ -79,8 +79,8 @@ const (
// the high order byte 0 is different from TLS Handshake (0x16).
Version = 6
// length of packet header
PktHeaderLen = 10 // = unsafe.Sizeof(PktHeader{}), but latter gives typed constant (uintptr)
// length of packet header in 'N'-encoding
PktHeaderLenN = 10 // = unsafe.Sizeof(PktHeaderN{}), but latter gives typed constant (uintptr)
// packets larger than PktMaxSize are not allowed.
// this helps to avoid out-of-memory error on packets with corrupt message len.
......@@ -95,12 +95,12 @@ const (
INVALID_OID zodb.Oid = 1<<64 - 1
)
// PktHeader represents header of a raw packet.
// PktHeaderN represents header of a raw packet in 'N'-encoding.
//
// A packet contains connection ID and message.
//
//neo:proto typeonly
type PktHeader struct {
type PktHeaderN struct {
ConnId packed.BE32 // NOTE is .msgid in py
MsgCode packed.BE16 // payload message code
MsgLen packed.BE32 // payload message length (excluding packet header)
......@@ -114,33 +114,50 @@ type Msg interface {
// on the wire.
neoMsgCode() uint16
// neoMsgEncodedLen returns how much space is needed to encode current message payload.
neoMsgEncodedLen() int
// neoMsgEncode encodes current message state into buf.
// for encoding E:
//
// - neoMsgEncodedLen<E> returns how much space is needed to encode current message payload via E encoding.
//
// - neoMsgEncode<E> encodes current message state into buf via E encoding.
//
// len(buf) must be >= neoMsgEncodedLen().
neoMsgEncode(buf []byte)
// len(buf) must be >= neoMsgEncodedLen<E>().
//
// - neoMsgDecode<E> decodes data via E encoding into message in-place.
// N encoding (original struct-based encoding)
neoMsgEncodedLenN() int
neoMsgEncodeN(buf []byte)
neoMsgDecodeN(data []byte) (nread int, err error)
// neoMsgDecode decodes data into message in-place.
neoMsgDecode(data []byte) (nread int, err error)
}
// MsgEncodedLen returns how much space is needed to encode msg payload.
func MsgEncodedLen(msg Msg) int {
return msg.neoMsgEncodedLen()
// Encoding represents messages encoding.
type Encoding byte
// MsgEncodedLen returns how much space is needed to encode msg payload via encoding e.
func (e Encoding) MsgEncodedLen(msg Msg) int {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgEncodedLenN()
}
}
// MsgEncode encodes msg state into buf.
// MsgEncode encodes msg state into buf via encoding e.
//
// len(buf) must be >= MsgEncodedLen(m).
func MsgEncode(msg Msg, buf []byte) {
msg.neoMsgEncode(buf)
// len(buf) must be >= e.MsgEncodedLen(m).
func (e Encoding) MsgEncode(msg Msg, buf []byte) {
switch e {
default: panic("bug")
case 'N': msg.neoMsgEncodeN(buf)
}
}
// MsgDecode decodes data into msg in-place.
func MsgDecode(msg Msg, data []byte) (nread int, err error) {
return msg.neoMsgDecode(data)
// MsgDecode decodes data via encoding e into msg in-place.
func (e Encoding) MsgDecode(msg Msg, data []byte) (nread int, err error) {
switch e {
default: panic("bug")
case 'N': return msg.neoMsgDecodeN(data)
}
}
......@@ -266,7 +283,7 @@ type Address struct {
}
// NOTE if Host == "" -> Port not added to wire (see py.PAddress):
func (a *Address) neoEncodedLen() int {
func (a *Address) neoEncodedLenN() int {
l := string_neoEncodedLen(a.Host)
if a.Host != "" {
l += 2
......@@ -274,7 +291,7 @@ func (a *Address) neoEncodedLen() int {
return l
}
func (a *Address) neoEncode(b []byte) int {
func (a *Address) neoEncodeN(b []byte) int {
n := string_neoEncode(a.Host, b[0:])
if a.Host != "" {
binary.BigEndian.PutUint16(b[n:], a.Port)
......@@ -283,7 +300,7 @@ func (a *Address) neoEncode(b []byte) int {
return n
}
func (a *Address) neoDecode(b []byte) (uint64, bool) {
func (a *Address) neoDecodeN(b []byte) (uint64, bool) {
n, ok := string_neoDecode(&a.Host, b)
if !ok {
return 0, false
......@@ -312,11 +329,11 @@ type PTid uint64
// IdTime represents time of identification.
type IdTime float64
func (t IdTime) neoEncodedLen() int {
func (t IdTime) neoEncodedLenN() int {
return 8
}
func (t IdTime) neoEncode(b []byte) int {
func (t IdTime) neoEncodeN(b []byte) int {
// use -inf as value for no data (NaN != NaN -> hard to use NaN in tests)
// NOTE neo/py uses None for "no data"; we use 0 for "no data" to avoid pointer
tt := float64(t)
......@@ -327,7 +344,7 @@ func (t IdTime) neoEncode(b []byte) int {
return 8
}
func (t *IdTime) neoDecode(data []byte) (uint64, bool) {
func (t *IdTime) neoDecodeN(data []byte) (uint64, bool) {
if len(data) < 8 {
return 0, false
}
......@@ -1210,13 +1227,13 @@ type FlushLog struct {}
// ---- runtime support for protogen and custom codecs ----
// customCodec is the interface that is implemented by types with custom encodings.
// customCodecN is the interface that is implemented by types with custom N encodings.
//
// its semantic is very similar to Msg.
type customCodec interface {
neoEncodedLen() int
neoEncode(buf []byte) (nwrote int)
neoDecode(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
type customCodecN interface {
neoEncodedLenN() int
neoEncodeN(buf []byte) (nwrote int)
neoDecodeN(data []byte) (nread uint64, ok bool) // XXX uint64 or int here?
}
func byte2bool(b byte) bool {
......
......@@ -68,42 +68,42 @@ func u64(v uint64) string {
return string(b[:])
}
func TestPktHeader(t *testing.T) {
// make sure PktHeader is really packed and its size matches PktHeaderLen
if unsafe.Sizeof(PktHeader{}) != 10 {
t.Fatalf("sizeof(PktHeader) = %v ; want 10", unsafe.Sizeof(PktHeader{}))
func TestPktHeaderN(t *testing.T) {
// make sure PktHeaderN is really packed and its size matches PktHeaderLenN
if unsafe.Sizeof(PktHeaderN{}) != 10 {
t.Fatalf("sizeof(PktHeaderN) = %v ; want 10", unsafe.Sizeof(PktHeaderN{}))
}
if unsafe.Sizeof(PktHeader{}) != PktHeaderLen {
t.Fatalf("sizeof(PktHeader) = %v ; want %v", unsafe.Sizeof(PktHeader{}), PktHeaderLen)
if unsafe.Sizeof(PktHeaderN{}) != PktHeaderLenN {
t.Fatalf("sizeof(PktHeaderN) = %v ; want %v", unsafe.Sizeof(PktHeaderN{}), PktHeaderLenN)
}
}
// test marshalling for one message type
func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
func testMsgMarshal(t *testing.T, enc Encoding, msg Msg, encoded string) {
typ := reflect.TypeOf(msg).Elem() // type of *msg
msg2 := reflect.New(typ).Interface().(Msg)
defer func() {
if e := recover(); e != nil {
t.Errorf("%v: panic ↓↓↓:", typ)
t.Errorf("%c/%v: panic ↓↓↓:", enc, typ)
panic(e) // to show traceback
}
}()
// msg.encode() == expected
msgCode := msg.neoMsgCode()
n := MsgEncodedLen(msg)
n := enc.MsgEncodedLen(msg)
msgType := MsgType(msgCode)
if msgType != typ {
t.Errorf("%v: msgCode = %v which corresponds to %v", typ, msgCode, msgType)
t.Errorf("%c/%v: msgCode = %v which corresponds to %v", enc, typ, msgCode, msgType)
}
if n != len(encoded) {
t.Errorf("%v: encodedLen = %v ; want %v", typ, n, len(encoded))
t.Errorf("%c/%v: encodedLen = %v ; want %v", enc, typ, n, len(encoded))
}
buf := make([]byte, n)
MsgEncode(msg, buf)
enc.MsgEncode(msg, buf)
if string(buf) != encoded {
t.Errorf("%v: encode result unexpected:", typ)
t.Errorf("%c/%v: encode result unexpected:", enc, typ)
t.Errorf("\thave: %s", hexpkg.EncodeToString(buf))
t.Errorf("\twant: %s", hexpkg.EncodeToString([]byte(encoded)))
}
......@@ -112,7 +112,7 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
for l := len(buf) - 1; l >= 0; l-- {
func() {
defer func() {
subj := fmt.Sprintf("%v: encode(buf[:encodedLen-%v])", typ, len(encoded)-l)
subj := fmt.Sprintf("%c/%v: encode(buf[:encodedLen-%v])", enc, typ, len(encoded)-l)
e := recover()
if e == nil {
t.Errorf("%s did not panic", subj)
......@@ -131,29 +131,29 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
}
}()
MsgEncode(msg, buf[:l])
enc.MsgEncode(msg, buf[:l])
}()
}
// msg.decode() == expected
data := []byte(encoded + "noise")
n, err := MsgDecode(msg2, data)
n, err := enc.MsgDecode(msg2, data)
if err != nil {
t.Errorf("%v: decode error %v", typ, err)
t.Errorf("%c/%v: decode error %v", enc, typ, err)
}
if n != len(encoded) {
t.Errorf("%v: nread = %v ; want %v", typ, n, len(encoded))
t.Errorf("%c/%v: nread = %v ; want %v", enc, typ, n, len(encoded))
}
if !reflect.DeepEqual(msg2, msg) {
t.Errorf("%v: decode result unexpected: %v ; want %v", typ, msg2, msg)
t.Errorf("%c/%v: decode result unexpected: %v ; want %v", enc, typ, msg2, msg)
}
// decode must detect buffer overflow
for l := len(encoded) - 1; l >= 0; l-- {
n, err = MsgDecode(msg2, data[:l])
n, err = enc.MsgDecode(msg2, data[:l])
if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%v: decode overflow not detected on [:%v]", typ, l)
t.Errorf("%c/%v: decode overflow not detected on [:%v]", enc, typ, l)
}
}
......@@ -162,8 +162,8 @@ func testMsgMarshal(t *testing.T, msg Msg, encoded string) {
// test encoding/decoding of messages
func TestMsgMarshal(t *testing.T) {
var testv = []struct {
msg Msg
encoded string // []byte
msg Msg
encodedN string // []byte
}{
// empty
{&Ping{}, ""},
......@@ -198,6 +198,7 @@ func TestMsgMarshal(t *testing.T) {
},
},
// N
hex("0102030405060708") +
hex("00000022") +
hex("00000003") +
......@@ -219,6 +220,7 @@ func TestMsgMarshal(t *testing.T) {
5: {4, 3, true},
}},
// N
u32(4) +
u64(1) + u64(1) + u64(0) + hex("00") +
u64(2) + u64(7) + u64(1) + hex("01") +
......@@ -238,6 +240,7 @@ func TestMsgMarshal(t *testing.T) {
MaxTID: 128,
},
// N
u32(4) +
u32(1) + u32(7) +
u32(2) + u32(9) +
......@@ -248,12 +251,13 @@ func TestMsgMarshal(t *testing.T) {
// uint32, []uint32
{&PartitionCorrupted{7, []NodeUUID{1, 3, 9, 4}},
// N
u32(7) + u32(4) + u32(1) + u32(3) + u32(9) + u32(4),
},
// uint32, Address, string, IdTime
{&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678, []string{"room1", "rack234"}, []uint32{3,4,5} },
// N
u8(2) + u32(17) + u32(9) +
"localhost" + u16(7777) +
u32(6) + "myname" +
......@@ -265,14 +269,17 @@ func TestMsgMarshal(t *testing.T) {
// IdTime, empty Address, int32
{&NotifyNodeInformation{1504466245.926185, []NodeInfo{
{CLIENT, Address{}, UUID(CLIENT, 1), RUNNING, 1504466245.925599}}},
// N
hex("41d66b15517b469d") + u32(1) +
u8(2) + u32(0) /* <- ø Address */ + hex("e0000001") + u8(2) +
hex("41d66b15517b3d04"),
},
// empty IdTime
{&NotifyNodeInformation{IdTimeNone, []NodeInfo{}}, hex("ffffffffffffffff") + hex("00000000")},
{&NotifyNodeInformation{IdTimeNone, []NodeInfo{}},
// N
hex("ffffffffffffffff") + hex("00000000"),
},
// TODO we need tests for:
// []varsize + trailing
......@@ -280,7 +287,7 @@ func TestMsgMarshal(t *testing.T) {
}
for _, tt := range testv {
testMsgMarshal(t, tt.msg, tt.encoded)
testMsgMarshal(t, 'N', tt.msg, tt.encodedN)
}
}
......@@ -288,23 +295,27 @@ func TestMsgMarshal(t *testing.T) {
// this way we additionally lightly check encode / decode overflow behaviour for all types.
func TestMsgMarshalAllOverflowLightly(t *testing.T) {
for _, typ := range msgTypeRegistry {
// zero-value for a type
msg := reflect.New(typ).Interface().(Msg)
l := MsgEncodedLen(msg)
zerol := make([]byte, l)
// decoding will turn nil slice & map into empty allocated ones.
// we need it so that reflect.DeepEqual works for msg encode/decode comparison
n, err := MsgDecode(msg, zerol)
if !(n == l && err == nil) {
t.Errorf("%v: zero-decode unexpected: %v, %v ; want %v, nil", typ, n, err, l)
}
for _, enc := range []Encoding{'N'} {
// zero-value for a type
msg := reflect.New(typ).Interface().(Msg)
l := enc.MsgEncodedLen(msg)
zerol := make([]byte, l)
// decoding will turn nil slice & map into empty allocated ones.
// we need it so that reflect.DeepEqual works for msg encode/decode comparison
n, err := enc.MsgDecode(msg, zerol)
if !(n == l && err == nil) {
t.Errorf("%c/%v: zero-decode unexpected: %v, %v ; want %v, nil", enc, typ, n, err, l)
}
testMsgMarshal(t, msg, string(zerol))
testMsgMarshal(t, enc, msg, string(zerol))
}
}
}
// Verify overflow handling on decode len checks
func TestMsgDecodeLenOverflow(t *testing.T) {
// Verify overflow handling on decodeN len checks
func TestMsgDecodeLenOverflowN(t *testing.T) {
enc := Encoding('N')
var testv = []struct {
msg Msg // of type to decode into
data string // []byte - tricky data to exercise decoder u32 len checks overflow
......@@ -325,7 +336,7 @@ func TestMsgDecodeLenOverflow(t *testing.T) {
}
}()
n, err := MsgDecode(tt.msg, data)
n, err := enc.MsgDecode(tt.msg, data)
if !(n == 0 && err == ErrDecodeOverflow) {
t.Errorf("%T: decode %x\nhave: %d, %v\nwant: %d, %v", tt.msg, data,
n, err, 0, ErrDecodeOverflow)
......
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment