Commit 8d0a1469 authored by Kirill Smelkov's avatar Kirill Smelkov

X Handshake draftly done

parent c5278b55
...@@ -26,6 +26,7 @@ import ( ...@@ -26,6 +26,7 @@ import (
"sync" "sync"
"sync/atomic" "sync/atomic"
"encoding/binary"
"fmt" "fmt"
) )
...@@ -119,6 +120,7 @@ const ( ...@@ -119,6 +120,7 @@ const (
// even (server) and its peer as odd (client). // even (server) and its peer as odd (client).
// //
// 2. NodeLink.Accept() works only on server side. // 2. NodeLink.Accept() works only on server side.
// XXX vs client processing e.g. invalidation notifications from master ?
// //
// Usually server role should be used for connections created via // Usually server role should be used for connections created via
// net.Listen/net.Accept and client role for connections created via net.Dial. // net.Listen/net.Accept and client role for connections created via net.Dial.
...@@ -547,17 +549,94 @@ func (nl *NodeLink) recvPkt() (*PktBuf, error) { ...@@ -547,17 +549,94 @@ func (nl *NodeLink) recvPkt() (*PktBuf, error) {
} }
// ---- Handshake ----
// Handshake performs NEO protocol handshake just after 2 nodes are connected
func Handshake(conn net.Conn) error {
return handshake(conn, PROTOCOL_VERSION)
}
func handshake(conn net.Conn, version uint32) error {
errch := make(chan error, 2)
go func() {
var b [4]byte
binary.BigEndian.PutUint32(b[:], version) // XXX -> hton32 ?
_, err := conn.Write(b[:])
// XXX EOF -> ErrUnexpectedEOF ?
errch <- err
}()
go func() {
var b [4]byte
_, err := io.ReadFull(conn, b[:])
if err == io.EOF {
err = io.ErrUnexpectedEOF // can be returned with n = 0
}
if err == nil {
peerVersion := binary.BigEndian.Uint32(b[:]) // XXX -> ntoh32 ?
if peerVersion != version {
err = fmt.Errorf("protocol version mismatch: peer = %08x ; our side = %08x", peerVersion, version)
}
}
errch <- err
}()
for i := 0; i < 2; i++ {
err := <-errch
if err != nil {
return &HandshakeError{conn.LocalAddr(), conn.RemoteAddr(), err}
}
}
return nil
}
type HandshakeError struct {
// XXX just keep .Conn? (but .Conn can be closed)
LocalAddr net.Addr
RemoteAddr net.Addr
Err error
}
func (e *HandshakeError) Error() string {
return fmt.Sprintf("%s - %s: handshake: %s", e.LocalAddr, e.RemoteAddr, e.Err.Error())
}
// ---- for convenience: Dial/Listen ---- // ---- for convenience: Dial/Listen ----
// Dial connects to address on named network and wrap the connection as NodeLink // Dial connects to address on named network and wrap the connection as NodeLink
// TODO +tls.Config // TODO +tls.Config
func Dial(ctx context.Context, network, address string) (*NodeLink, error) { func Dial(ctx context.Context, network, address string) (nl *NodeLink, err error) {
d := net.Dialer{} d := net.Dialer{}
peerConn, err := d.DialContext(ctx, network, address) peerConn, err := d.DialContext(ctx, network, address)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// do the handshake. don't forget to close peerConn if we return with an error
defer func() {
if err != nil {
peerConn.Close()
}
}()
errch := make(chan error)
go func() {
errch <- Handshake(peerConn)
}()
select {
case <-ctx.Done():
return nil, ctx.Err()
case err = <-errch:
if err != nil {
return nil, &HandshakeError{peerConn.LocalAddr(), peerConn.RemoteAddr(), err}
}
}
// handshake ok -> NodeLink ready
return NewNodeLink(peerConn, LinkClient), nil return NewNodeLink(peerConn, LinkClient), nil
} }
...@@ -571,6 +650,13 @@ func (l *Listener) Accept() (*NodeLink, error) { ...@@ -571,6 +650,13 @@ func (l *Listener) Accept() (*NodeLink, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
err = Handshake(peerConn)
if err != nil {
peerConn.Close()
return nil, err
}
return NewNodeLink(peerConn, LinkServer), nil return NewNodeLink(peerConn, LinkServer), nil
} }
......
...@@ -99,6 +99,11 @@ func xwait(w interface { Wait() error }) { ...@@ -99,6 +99,11 @@ func xwait(w interface { Wait() error }) {
exc.Raiseif(err) exc.Raiseif(err)
} }
func xhandshake(c net.Conn, version uint32) {
err := handshake(c, version)
exc.Raiseif(err)
}
// Prepare PktBuf with content // Prepare PktBuf with content
func _mkpkt(connid uint32, msgcode uint16, payload []byte) *PktBuf { func _mkpkt(connid uint32, msgcode uint16, payload []byte) *PktBuf {
pkt := &PktBuf{make([]byte, PktHeadLen + len(payload))} pkt := &PktBuf{make([]byte, PktHeadLen + len(payload))}
...@@ -530,3 +535,62 @@ func TestNodeLink(t *testing.T) { ...@@ -530,3 +535,62 @@ func TestNodeLink(t *testing.T) {
xclose(nl1) xclose(nl1)
xclose(nl2) xclose(nl2)
} }
func TestHandshake(t *testing.T) {
// handshake ok
p1, p2 := net.Pipe()
wg := WorkGroup()
wg.Gox(func() {
xhandshake(p1, 1)
})
wg.Gox(func() {
xhandshake(p2, 1)
})
xwait(wg)
xclose(p1)
xclose(p2)
// version mismatch
p1, p2 = net.Pipe()
var err1, err2 error
wg = WorkGroup()
wg.Gox(func() {
err1 = handshake(p1, 1)
})
wg.Gox(func() {
err2 = handshake(p2, 2)
})
xwait(wg)
xclose(p1)
xclose(p2)
err1Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000002 ; our side = 00000001"
err2Want := "pipe - pipe: handshake: protocol version mismatch: peer = 00000001 ; our side = 00000002"
if !(err1 != nil && err1.Error() == err1Want) {
t.Errorf("handshake ver mismatch: p1: unexpected error:\nhave: %v\nwant: %v", err1, err1Want)
}
if !(err2 != nil && err2.Error() == err2Want) {
t.Errorf("handshake ver mismatch: p2: unexpected error:\nhave: %v\nwant: %v", err2, err2Want)
}
// tx & rx problem
p1, p2 = net.Pipe()
err1, err2 = nil, nil
wg = WorkGroup()
wg.Gox(func() {
err1 = handshake(p1, 1)
})
wg.Gox(func() {
xclose(p2)
})
xwait(wg)
xclose(p1)
err11, ok := err1.(*HandshakeError)
if !ok || !(err11.Err == io.ErrClosedPipe /* on Write */ || err11.Err == io.ErrUnexpectedEOF /* on Read */) {
t.Errorf("handshake peer close: unexpected error: %#v", err1)
}
}
...@@ -235,9 +235,9 @@ func TestPktMarshal(t *testing.T) { ...@@ -235,9 +235,9 @@ func TestPktMarshal(t *testing.T) {
}, },
// uint32, Address, string, float64 // uint32, Address, string, float64
{&RequestIdentification{8, CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678}, {&RequestIdentification{CLIENT, 17, Address{"localhost", 7777}, "myname", 0.12345678},
u32(8) + u32(2) + u32(17) + u32(9) + u32(2) + u32(17) + u32(9) +
"localhost" + u16(7777) + "localhost" + u16(7777) +
u32(6) + "myname" + u32(6) + "myname" +
hex("3fbf9add1091c895"), hex("3fbf9add1091c895"),
......
...@@ -110,11 +110,6 @@ func IdentifyPeer(link *NodeLink, myNodeType NodeType) (nodeInfo RequestIdentifi ...@@ -110,11 +110,6 @@ func IdentifyPeer(link *NodeLink, myNodeType NodeType) (nodeInfo RequestIdentifi
// XXX also handle Error // XXX also handle Error
case *RequestIdentification: case *RequestIdentification:
if pkt.ProtocolVersion != PROTOCOL_VERSION {
// TODO also tell peer with Error
return nodeInfo, fmt.Errorf("protocol version mismatch: peer = %d ; our side = %d", pkt.ProtocolVersion, PROTOCOL_VERSION)
}
// TODO (.NodeType, .UUID, .Address, .Name, .IdTimestamp) -> check + register to NM // TODO (.NodeType, .UUID, .Address, .Name, .IdTimestamp) -> check + register to NM
err = EncodeAndSend(conn, &AcceptIdentification{ err = EncodeAndSend(conn, &AcceptIdentification{
...@@ -123,8 +118,6 @@ func IdentifyPeer(link *NodeLink, myNodeType NodeType) (nodeInfo RequestIdentifi ...@@ -123,8 +118,6 @@ func IdentifyPeer(link *NodeLink, myNodeType NodeType) (nodeInfo RequestIdentifi
NumPartitions: 0, // XXX NumPartitions: 0, // XXX
NumReplicas: 0, // XXX NumReplicas: 0, // XXX
YourNodeID: pkt.NodeID, YourNodeID: pkt.NodeID,
Primary: Address{}, // XXX
//KnownMasterList: // XXX
}) })
if err != nil { if err != nil {
...@@ -152,7 +145,6 @@ func IdentifyMe(link *NodeLink, nodeType NodeType /*XXX*/) (peerType NodeType, e ...@@ -152,7 +145,6 @@ func IdentifyMe(link *NodeLink, nodeType NodeType /*XXX*/) (peerType NodeType, e
}() }()
err = EncodeAndSend(conn, &RequestIdentification{ err = EncodeAndSend(conn, &RequestIdentification{
ProtocolVersion: PROTOCOL_VERSION,
NodeType: nodeType, NodeType: nodeType,
NodeID: 0, // XXX NodeID: 0, // XXX
Address: Address{}, // XXX Address: Address{}, // XXX
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment