Commit 3354b401 authored by Kirill Smelkov's avatar Kirill Smelkov

xnet: Adjust Networker.Listen to return listener that can handle cancellation in Accept

We already handle cancellation in Dial, but Accepting was out of luck
until now. This makes it more difficult for clients to implement and
wrap acceptors where they need to handle cancellations. This also makes
it possible for a test or program to get stuck in Accept loop if it is
not careful enough to manually handle ctx cancel around Accept calls.

-> Fix it in one place - here, in xnet - so that users are offloaded
from all this and can just call Accept(ctx) and rely on underlying
implementation to handle ctx cancel.

This patch:

- introduces xnet.Listener interface, which is like net.Listener, but
  Accept goes with ctx argument.
- changes Networker.Listen signature to return xnet.Listener instead of
  net.Listener. While we are here - changing it - also add ctx argument
  to Listen call itself.
- Adds listenerCtx - which, given net.Listener, provides xnet.Listener
  by wrapping some logic around original.
- Adapts NetPlain, NetTLS and NetTrace to provide updated interface.

We'll fix up everything in other packages to match/use updated interface
in the next patch.
parent 9cfe4f52
// Copyright (C) 2017-2019 Nexedi SA and Contributors.
// Copyright (C) 2017-2020 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
......@@ -27,6 +27,8 @@ import (
"os"
"crypto/tls"
"lab.nexedi.com/kirr/go123/xsync"
)
// Networker is interface representing access-point to a streaming network.
......@@ -48,9 +50,16 @@ type Networker interface {
// Listen starts listening on local address laddr on underlying network access-point.
//
// See net.Listen for semantic details.
//
// XXX also introduce xnet.Listener in which Accept() accepts also ctx?
Listen(laddr string) (net.Listener, error)
Listen(ctx context.Context, laddr string) (Listener, error)
}
// Listener amends net.Listener for Accept to handle cancellation.
type Listener interface {
Accept(ctx context.Context) (net.Conn, error)
// same as in net.Listener
Close() error
Addr() net.Addr
}
......@@ -87,8 +96,13 @@ func (n *netPlain) Dial(ctx context.Context, addr string) (net.Conn, error) {
return d.DialContext(ctx, n.network, addr)
}
func (n *netPlain) Listen(laddr string) (net.Listener, error) {
return net.Listen(n.network, laddr)
func (n *netPlain) Listen(ctx context.Context, laddr string) (Listener, error) {
lc := net.ListenConfig{}
rawl, err := lc.Listen(ctx, n.network, laddr)
if err != nil {
return nil, err
}
return newListenerCtx(rawl), nil
}
// NetTLS wraps underlying networker with TLS layer according to config.
......@@ -122,10 +136,121 @@ func (n *netTLS) Dial(ctx context.Context, addr string) (net.Conn, error) {
return tls.Client(c, n.config), nil
}
func (n *netTLS) Listen(laddr string) (net.Listener, error) {
l, err := n.inner.Listen(laddr)
func (n *netTLS) Listen(ctx context.Context, laddr string) (Listener, error) {
l, err := n.inner.Listen(ctx, laddr)
if err != nil {
return nil, err
}
return &listenerTLS{l, n}, nil
}
// listenerTLS implements Listener for netTLS.
type listenerTLS struct {
innerl Listener
net *netTLS
}
func (l *listenerTLS) Close() error {
return l.innerl.Close()
}
func (l *listenerTLS) Addr() net.Addr {
return l.innerl.Addr()
}
func (l *listenerTLS) Accept(ctx context.Context) (net.Conn, error) {
conn, err := l.innerl.Accept(ctx)
if err != nil {
return nil, err
}
return tls.NewListener(l, n.config), nil
return tls.Server(conn, l.net.config), nil
}
// ----------------------------------------
// listenerCtx provides Listener given net.Listener.
type listenerCtx struct {
rawl net.Listener // underlying listener
serveWG *xsync.WorkGroup // Accept loop is run under serveWG
serveCancel func() // Close calls serveCancel to request Accept loop shutdown
acceptq chan accepted // Accept results go -> acceptq
}
// accepted represents Accept result.
type accepted struct {
conn net.Conn
err error
}
func newListenerCtx(rawl net.Listener) *listenerCtx {
l := &listenerCtx{rawl: rawl, acceptq: make(chan accepted)}
ctx, cancel := context.WithCancel(context.Background())
l.serveWG = xsync.NewWorkGroup(ctx)
l.serveCancel = cancel
l.serveWG.Go(l.serve)
return l
}
func (l *listenerCtx) serve(ctx context.Context) error {
for {
// raw Accept. This should not stuck overliving ctx as Close closes rawl
conn, err := l.rawl.Accept()
// send result to Accept, but don't try to send if we are closed
ctxErr := ctx.Err()
if ctxErr == nil {
select {
case <-ctx.Done():
// closed
ctxErr = ctx.Err()
case l.acceptq <- accepted{conn, err}:
// ok
}
}
// shutdown if we are closed
if ctxErr != nil {
if conn != nil {
conn.Close() // ignore err
}
return ctxErr
}
}
}
func (l *listenerCtx) Close() error {
l.serveCancel()
err := l.rawl.Close()
_ = l.serveWG.Wait() // ignore err - it is always "canceled"
return err
}
func (l *listenerCtx) Accept(ctx context.Context) (_ net.Conn, err error) {
err = ctx.Err()
// don't try to pull from acceptq if ctx is already canceled
if err == nil {
select {
case <-ctx.Done():
err = ctx.Err()
case a := <-l.acceptq:
return a.conn, a.err
}
}
// here it is always due to ctx cancel
laddr := l.rawl.Addr()
return nil, &net.OpError{
Op: "accept",
Net: laddr.Network(),
Source: nil,
Addr: laddr,
Err: err,
}
}
func (l *listenerCtx) Addr() net.Addr {
return l.rawl.Addr()
}
// Copyright (C) 2017-2019 Nexedi SA and Contributors.
// Copyright (C) 2017-2020 Nexedi SA and Contributors.
// Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
......@@ -104,9 +104,9 @@ func (nt *netTrace) Dial(ctx context.Context, addr string) (net.Conn, error) {
return &traceConn{nt, c}, nil
}
func (nt *netTrace) Listen(laddr string) (net.Listener, error) {
func (nt *netTrace) Listen(ctx context.Context, laddr string) (Listener, error) {
// XXX +TraceNetListenPre ?
l, err := nt.inner.Listen(laddr)
l, err := nt.inner.Listen(ctx, laddr)
if err != nil {
return nil, err
}
......@@ -117,11 +117,11 @@ func (nt *netTrace) Listen(laddr string) (net.Listener, error) {
// netTraceListener wraps net.Listener to wrap accepted connections with traceConn.
type netTraceListener struct {
nt *netTrace
net.Listener
Listener
}
func (ntl *netTraceListener) Accept() (net.Conn, error) {
c, err := ntl.Listener.Accept()
func (ntl *netTraceListener) Accept(ctx context.Context) (net.Conn, error) {
c, err := ntl.Listener.Accept(ctx)
if err != nil {
return nil, err
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment