Commit 9ae505ad authored by Kirill Smelkov's avatar Kirill Smelkov

X neonet: Prevent DOS with too-big MsgPack frame

parent e1c25b29
...@@ -105,7 +105,6 @@ import ( ...@@ -105,7 +105,6 @@ import (
"lab.nexedi.com/kirr/neo/go/neo/internal/msgpack" "lab.nexedi.com/kirr/neo/go/neo/internal/msgpack"
"lab.nexedi.com/kirr/neo/go/neo/proto" "lab.nexedi.com/kirr/neo/go/neo/proto"
"github.com/philhofer/fwd"
"github.com/someonegg/gocontainer/rbuf" "github.com/someonegg/gocontainer/rbuf"
"github.com/tinylib/msgp/msgp" "github.com/tinylib/msgp/msgp"
...@@ -159,6 +158,7 @@ type NodeLink struct { ...@@ -159,6 +158,7 @@ type NodeLink struct {
rxbufN rbuf.RingBuf // buffer for reading from peerLink (N encoding) rxbufN rbuf.RingBuf // buffer for reading from peerLink (N encoding)
rxbufM *msgp.Reader // ----//---- (M encoding) rxbufM *msgp.Reader // ----//---- (M encoding)
rxbufMlimit *io.LimitedReader // limiter inserted inbetween rxbufM and peerLink
// scheduling optimization: whenever serveRecv sends to Conn.rxq // scheduling optimization: whenever serveRecv sends to Conn.rxq
// receiving side must ack here to receive G handoff. // receiving side must ack here to receive G handoff.
...@@ -275,7 +275,7 @@ const ( ...@@ -275,7 +275,7 @@ const (
// users should always use Handshake which performs protocol handshaking first. // users should always use Handshake which performs protocol handshaking first.
// //
// rxbuf if != nil indicates what was already read-buffered from conn. // rxbuf if != nil indicates what was already read-buffered from conn.
func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *fwd.Reader) *NodeLink { func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *xbufReader) *NodeLink {
var nextConnId uint32 var nextConnId uint32
switch role &^ linkFlagsMask { switch role &^ linkFlagsMask {
case _LinkServer: case _LinkServer:
...@@ -299,7 +299,7 @@ func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *fwd.R ...@@ -299,7 +299,7 @@ func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *fwd.R
} }
if rxbuf == nil { if rxbuf == nil {
rxbuf = fwd.NewReader(conn) rxbuf = newXBufReader(conn, 0)
} }
switch enc { switch enc {
case 'N': case 'N':
...@@ -311,7 +311,9 @@ func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *fwd.R ...@@ -311,7 +311,9 @@ func newNodeLink(conn net.Conn, enc proto.Encoding, role _LinkRole, rxbuf *fwd.R
nl.rxbufN.Write(b) nl.rxbufN.Write(b)
case 'M': case 'M':
nl.rxbufM = &msgp.Reader{R: rxbuf} nl.rxbufM = &msgp.Reader{R: rxbuf.Reader}
nl.rxbufMlimit = &rxbuf.Limit
nl.rxbufMlimit.N = 0 // reads will fail unless .N is explicitly reset
default: default:
panic("bug") panic("bug")
} }
...@@ -1309,8 +1311,16 @@ func (nl *NodeLink) recvPktN() (*pktBuf, error) { ...@@ -1309,8 +1311,16 @@ func (nl *NodeLink) recvPktN() (*pktBuf, error) {
func (nl *NodeLink) recvPktM() (*pktBuf, error) { func (nl *NodeLink) recvPktM() (*pktBuf, error) {
pkt := pktAlloc(4096) pkt := pktAlloc(4096)
mraw := msgp.Raw(pkt.data) mraw := msgp.Raw(pkt.data)
err := mraw.DecodeMsg(nl.rxbufM) // XXX limit size of one packet to proto.PktMaxSize (= UNPACK_BUFFER_SIZE in NEO/py speak)
// limit size of one packet to proto.PktMaxSize
// we don't care if it will be slightly more with what is already buffered
nl.rxbufMlimit.N = proto.PktMaxSize
err := mraw.DecodeMsg(nl.rxbufM)
if err != nil { if err != nil {
if nl.rxbufMlimit.N <= 0 {
err = ErrPktTooBig
}
return nil, err return nil, err
} }
pkt.data = []byte(mraw) pkt.data = []byte(mraw)
......
// Copyright (C) 2016-2018 Nexedi SA and Contributors. // Copyright (C) 2016-2021 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com> // Kirill Smelkov <kirr@nexedi.com>
// //
// This program is free software: you can Use, Study, Modify and Redistribute // This program is free software: you can Use, Study, Modify and Redistribute
...@@ -20,7 +20,13 @@ ...@@ -20,7 +20,13 @@
package neonet package neonet
// syntax sugar for atomic load/store to raise signal/noise in logic // syntax sugar for atomic load/store to raise signal/noise in logic
import "sync/atomic" import (
"io"
"sync/atomic"
"github.com/philhofer/fwd"
)
type atomic32 struct { type atomic32 struct {
v int32 // struct member so `var a atomic32; if a == 0 ...` does not work v int32 // struct member so `var a atomic32; if a == 0 ...` does not work
...@@ -37,3 +43,16 @@ func (a *atomic32) Set(v int32) { ...@@ -37,3 +43,16 @@ func (a *atomic32) Set(v int32) {
func (a *atomic32) Add(δ int32) int32 { func (a *atomic32) Add(δ int32) int32 {
return atomic.AddInt32(&a.v, δ) return atomic.AddInt32(&a.v, δ)
} }
// xbufReader provides fwd.Reader with io.LimitedReader inserted underneath it.
type xbufReader struct {
*fwd.Reader
Limit io.LimitedReader // .Reader reads through .limit
}
func newXBufReader(r io.Reader, n int64) *xbufReader {
rxbuf := &xbufReader{Limit: io.LimitedReader{R: r, N: n}}
rxbuf.Reader = fwd.NewReader(&rxbuf.Limit)
return rxbuf
}
...@@ -28,7 +28,6 @@ import ( ...@@ -28,7 +28,6 @@ import (
"io" "io"
"net" "net"
"github.com/philhofer/fwd"
"github.com/tinylib/msgp/msgp" "github.com/tinylib/msgp/msgp"
"lab.nexedi.com/kirr/go123/xerr" "lab.nexedi.com/kirr/go123/xerr"
...@@ -104,14 +103,14 @@ func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (*NodeL ...@@ -104,14 +103,14 @@ func handshakeServer(ctx context.Context, conn net.Conn, version uint32) (*NodeL
return newNodeLink(conn, enc, _LinkServer, rxbuf), nil return newNodeLink(conn, enc, _LinkServer, rxbuf), nil
} }
func _handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPrefer proto.Encoding) (enc proto.Encoding, rxbuf *fwd.Reader, err error) { func _handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPrefer proto.Encoding) (enc proto.Encoding, rxbuf *xbufReader, err error) {
defer func() { defer func() {
if err != nil { if err != nil {
err = &_HandshakeError{_LinkClient, conn.LocalAddr(), conn.RemoteAddr(), err} err = &_HandshakeError{_LinkClient, conn.LocalAddr(), conn.RemoteAddr(), err}
} }
}() }()
rxbuf = fwd.NewReader(conn) rxbuf = newXBufReader(conn, /*any non-small*/1024)
var peerEnc proto.Encoding var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error { err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
...@@ -144,14 +143,14 @@ func _handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPre ...@@ -144,14 +143,14 @@ func _handshakeClient(ctx context.Context, conn net.Conn, version uint32, encPre
return peerEnc, rxbuf, nil return peerEnc, rxbuf, nil
} }
func _handshakeServer(ctx context.Context, conn net.Conn, version uint32) (enc proto.Encoding, rxbuf *fwd.Reader, err error) { func _handshakeServer(ctx context.Context, conn net.Conn, version uint32) (enc proto.Encoding, rxbuf *xbufReader, err error) {
defer func() { defer func() {
if err != nil { if err != nil {
err = &_HandshakeError{_LinkServer, conn.LocalAddr(), conn.RemoteAddr(), err} err = &_HandshakeError{_LinkServer, conn.LocalAddr(), conn.RemoteAddr(), err}
} }
}() }()
rxbuf = fwd.NewReader(conn) rxbuf = newXBufReader(conn, /*any non-small*/1024)
var peerEnc proto.Encoding var peerEnc proto.Encoding
err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error { err = xcontext.WithCloseOnErrCancel(ctx, conn, func() error {
...@@ -217,7 +216,7 @@ func txHello(errctx string, conn net.Conn, version uint32, enc proto.Encoding) ( ...@@ -217,7 +216,7 @@ func txHello(errctx string, conn net.Conn, version uint32, enc proto.Encoding) (
return nil return nil
} }
func rxHello(errctx string, rx *fwd.Reader) (enc proto.Encoding, version uint32, err error) { func rxHello(errctx string, rx *xbufReader) (enc proto.Encoding, version uint32, err error) {
defer xerr.Context(&err, errctx) defer xerr.Context(&err, errctx)
b := make([]byte, 4) b := make([]byte, 4)
...@@ -248,7 +247,7 @@ func rxHello(errctx string, rx *fwd.Reader) (enc proto.Encoding, version uint32, ...@@ -248,7 +247,7 @@ func rxHello(errctx string, rx *fwd.Reader) (enc proto.Encoding, version uint32,
} }
peerEnc = 'M' peerEnc = 'M'
rxM := msgp.Reader{R: rx} rxM := msgp.Reader{R: rx.Reader}
peerVer, err = rxM.ReadUint32() peerVer, err = rxM.ReadUint32()
if err != nil { if err != nil {
return 0, 0, fmt.Errorf("M: recv peer version: %s", err) // XXX + "read magic" ctx return 0, 0, fmt.Errorf("M: recv peer version: %s", err) // XXX + "read magic" ctx
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment