// Copyright (C) 2016-2021  Nexedi SA and Contributors.
//                          Kirill Smelkov <kirr@nexedi.com>
//
// This program is free software: you can Use, Study, Modify and Redistribute
// it under the terms of the GNU General Public License version 3, or (at your
// option) any later version, as published by the Free Software Foundation.
//
// You can also Link and Combine this program with other software covered by
// the terms of any of the Free Software licenses or any of the Open Source
// Initiative approved licenses and Convey the resulting work. Corresponding
// source of such a combination shall include the source code for all other
// software used.
//
// This program is distributed WITHOUT ANY WARRANTY; without even the implied
// warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
//
// See COPYING file for full licensing terms.
// See https://www.nexedi.com/licensing for rationale and options.

package neo
// storage node

import (
	"context"
	"errors"
	"fmt"
	"sync"

	"lab.nexedi.com/kirr/go123/xcontext"
	"lab.nexedi.com/kirr/go123/xerr"
	"lab.nexedi.com/kirr/go123/xnet"
	"lab.nexedi.com/kirr/go123/xsync"

	"lab.nexedi.com/kirr/neo/go/neo/neonet"
	"lab.nexedi.com/kirr/neo/go/neo/proto"
	"lab.nexedi.com/kirr/neo/go/neo/storage"
	"lab.nexedi.com/kirr/neo/go/neo/xneo"
	"lab.nexedi.com/kirr/neo/go/zodb"

	"lab.nexedi.com/kirr/neo/go/internal/log"
	"lab.nexedi.com/kirr/neo/go/internal/task"
	xxcontext "lab.nexedi.com/kirr/neo/go/internal/xcontext"
	taskctx "lab.nexedi.com/kirr/neo/go/internal/xcontext/task"
	"lab.nexedi.com/kirr/neo/go/internal/xio"
	"lab.nexedi.com/kirr/neo/go/internal/xzodb"
)

// Storage is a NEO node that keeps data and provides read/write access to it via the network.
//
// Storage implements only NEO protocol logic, with data being persisted via the provided storage.Backend.
type Storage struct {
	node *_MasteredNode
	lli  xneo.Listener
	back storage.Backend

	// whole Run runs under runCtx
	runCtx context.Context
}

// NewStorage creates a new storage node that will talk to the master on masterAddr.
//
// The storage uses back as the underlying backend for storing data.
// Use Run to actually start running the node.
func NewStorage(clusterName, masterAddr string, net xnet.Networker, back storage.Backend) *Storage {
	return &Storage{
		node: newMasteredNode(proto.STORAGE, clusterName, net, masterAddr),
		back: back,
	}
}

// Run starts the storage node and runs it until either ctx is cancelled or the
// master commands it to shut down.
//
// The storage will be serving incoming connections on l.
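//
// A minimal usage sketch (illustrative only; only NewStorage and Run are defined
// in this file -- how the listener, backend and ctx are obtained below is an
// assumption about the go123/xnet and storage backend APIs, not a verified recipe):
//
//	net  := xnet.NetPlain("tcp")
//	l, _ := net.Listen(ctx, "127.0.0.1:5554")                     // address the storage serves on
//	stor := NewStorage("mycluster", "127.0.0.1:5551", net, back)  // back: a storage.Backend opened elsewhere
//	err  := stor.Run(ctx, l)                                      // blocks until ctx cancel or master-commanded shutdown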
func (stor *Storage) Run(ctx context.Context, l xnet.Listener) (err error) {
	addr := l.Addr()
	log.Infof(ctx, "%s: listening on %s ...", stor.node.MyInfo.NID, addr)
	stor.runCtx = ctx

	// update our serving address in node
	naddr, err := proto.Addr(addr)
	if err != nil {
		return fmt.Errorf("run @%s: %s", addr, err)
	}
	stor.node.MyInfo.Addr = naddr

	// wrap listener with link / identification hello checker
	stor.lli = xneo.NewListener(neonet.NewLinkListener(l))
	defer func() {
		__ := stor.lli.Close()
		err = xerr.First(err, __)
	}()
	defer func() {
		__ := stor.back.Close()
		err = xerr.First(err, __)
	}()

	// connect to master and let it drive us via commands and updates
	return stor.node.TalkMaster(ctx, func(ctx context.Context, mlink *_MasterLink) error {
		// XXX move -> SetNumReplicas handler
		//
		// NumReplicas: neo/py meaning for n(replica) = `n(real-replica) - 1`
		// if !(accept.NumPartitions == 1 && accept.NumReplicas == 0) {
		//	return fmt.Errorf("TODO for 1-storage POC: Npt: %d Nreplica: %d", accept.NumPartitions, accept.NumReplicas)
		// }

		// let master initialize us. If successful this ends with StartOperation command.
		reqStart, err := stor.m1initialize(ctx, mlink)
		if err != nil {
			return err
		}

		// we got StartOperation command. Let master drive us during service phase.
		return stor.m1serve(ctx, mlink, reqStart)
	})
}

// m1initialize drives the storage by master messages during the initialization phase.
//
// Initialization includes the master retrieving info for cluster recovery and data
// verification before starting operation. Initialization finishes either
// successfully with receiving the master command to start operation, or
// unsuccessfully with connection closing indicating initialization was
// cancelled or some other error.
//
// The returned error indicates:
//   - nil:  initialization was ok and a command came from the master to start operation.
//   - !nil: initialization was cancelled or failed somehow.
func (stor *Storage) m1initialize(ctx context.Context, mlink *_MasterLink) (reqStart *neonet.Request, err error) {
	defer task.Runningf(&ctx, "mserve init")(&err)

	for {
		req, err := mlink.Recv1(ctx)
		if err != nil {
			return nil, err
		}

		err = stor.m1initialize1(ctx, req)
		if err == cmdStart {
			// start - transition to serve
			return &req, nil
		}
		req.Close()
		if err != nil {
			return nil, err
		}
	}
}

var cmdStart = errors.New("start requested")

// m1initialize1 handles one message from the master under m1initialize.
func (stor *Storage) m1initialize1(ctx context.Context, req neonet.Request) error {
	var err error

	switch msg := req.Msg.(type) {
	default:
		return fmt.Errorf("unexpected message: %T", msg)

	case *proto.StartOperation:
		// ok, transition to serve
		return cmdStart

	case *proto.Recovery:
		err = req.Reply(&proto.AnswerRecovery{
			PTid:        stor.node.State.PartTab.PTid,
			BackupTid:   proto.INVALID_TID,
			TruncateTid: proto.INVALID_TID})

	case *proto.AskPartitionTable:
		// TODO initially read PT from disk
		err = req.Reply(&proto.AnswerPartitionTable{
			PTid:        stor.node.State.PartTab.PTid,
			NumReplicas: 0, // FIXME hardcoded; NEO/py uses this as n(replica)-1
			RowList:     stor.node.State.PartTab.Dump()})

	case *proto.LockedTransactions:
		// XXX r/o stub
		err = req.Reply(&proto.AnswerLockedTransactions{})

	// TODO AskUnfinishedTransactions

	case *proto.LastIDs:
		lastTid, zerr1 := stor.back.LastTid(ctx)
		lastOid, zerr2 := stor.back.LastOid(ctx)
		if zerr := xerr.First(zerr1, zerr2); zerr != nil {
			return zerr // TODO send the error to M ?
		}

		err = req.Reply(&proto.AnswerLastIDs{LastTid: lastTid, LastOid: lastOid})

	case *proto.SendPartitionTable:
		// TODO M sends us whole PT -> save locally

	case *proto.NotifyPartitionChanges:
		// TODO M sends us δPT -> save locally
	}

	return err
}

// m1serve drives the storage by master messages during the service phase.
//
// Service is the regular phase, serving requests from clients to load/save objects,
// handling transaction commit (with the master) and syncing data with other
// storage nodes.
//
// It always returns with an error describing why serve had to be stopped -
// either the master commanded us to stop, the context was cancelled, or some
// other error occurred.
func (stor *Storage) m1serve(ctx context.Context, mlink *_MasterLink, reqStart *neonet.Request) (err error) {
	defer task.Runningf(&ctx, "mserve")(&err)

	// serve clients while operational
	serveCtx := taskctx.Runningf(stor.runCtx, "%s", stor.node.MyInfo.NID)
	serveCtx, serveCancel := xcontext.Merge/*Cancel*/(serveCtx, ctx)
	wg := sync.WaitGroup{}
	wg.Add(1)
	go func() {
		defer wg.Done()
		stor.serve(serveCtx)
	}()
	defer wg.Wait()
	defer serveCancel()

	// reply to M that we are ready
	// NOTE NEO/py sends NotifyReady on another conn; we patched py: see
	// https://lab.nexedi.com/kirr/neo/commit/4eaaf186 for context.
	err = reqStart.Reply(&proto.NotifyReady{})
	reqStart.Close()
	if err != nil {
		return err
	}

	for {
		req, err := mlink.Recv1(ctx)
		if err != nil {
			return err
		}

		err = stor.m1serve1(ctx, req)
		req.Close()
		if err != nil {
			return err
		}
	}
}

// m1serve1 handles one message from the master under m1serve.
func (stor *Storage) m1serve1(ctx context.Context, req neonet.Request) error {
	switch msg := req.Msg.(type) {
	default:
		return fmt.Errorf("unexpected message: %T", msg)

	case *proto.StopOperation:
		return fmt.Errorf("stop requested")

	// TODO commit related messages
	}

	return nil
}

// --- serve incoming connections from other nodes ---

func (stor *Storage) serve(ctx context.Context) (err error) {
	defer task.Runningf(&ctx, "serve")(&err)

	wg := sync.WaitGroup{}
	defer wg.Wait()

	// XXX ? -> _MasteredNode.Accept(lli)  (it will verify IdTime against .nodeTab[nid])
	// XXX ? -> Node.Serve(lli -> func(idReq))
	for {
		if ctx.Err() != nil {
			return ctx.Err()
		}

		req, idReq, err := stor.lli.Accept(ctx)
		if err != nil {
			if !xxcontext.Canceled(err) {
				log.Error(ctx, err) // XXX throttle?
			}
			continue
		}

		wg.Add(1)
		go func() {
			defer wg.Done()
			err := xio.WithCloseOnRetCancel(ctx, req.Link(), func() error {
				return stor.serveLink(ctx, req, idReq)
			})
			if err != nil {
				if ctx.Err() == nil {
					// the error is not due to serve cancel
					log.Error(ctx, err)
				}
			}
		}()
	}
}

// identify processes identification request from connected peer.
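//
// Currently only CLIENT nodes with a matching cluster name are accepted (see
// identify_ below); anything else is answered with a proto.Error. A sketch of
// the exchange, with made-up field values for illustration:
//
//	-> RequestIdentification{NodeType: CLIENT, ClusterName: "mycluster", NID: <tmp nid>, ...}
//	<- AcceptIdentification{NodeType: STORAGE, MyNID: <our nid>, YourNID: <tmp nid>}   (accepted)
//	<- Error{PROTOCOL_ERROR, "cluster name mismatch"}                                  (rejected)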
func (stor *Storage) identify(ctx context.Context, idReq *proto.RequestIdentification) (idResp proto.Msg, err error) {
	accept, reject := stor.identify_(idReq)
	if accept != nil {
		log.Info(ctx, "accepting identification")
		idResp = accept
	} else {
		log.Infof(ctx, "rejecting identification (%s)", reject.Message)
		idResp = reject
	}

	var ereject error
	if reject != nil {
		ereject = reject
	}
	return idResp, ereject
}

func (stor *Storage) identify_(idReq *proto.RequestIdentification) (proto.Msg, *proto.Error) {
	// XXX stub: we accept clients and don't care about their NID/IDtime
	if idReq.NodeType != proto.CLIENT {
		return nil, &proto.Error{proto.PROTOCOL_ERROR, "only clients are accepted"}
	}
	if idReq.ClusterName != stor.node.ClusterName {
		return nil, &proto.Error{proto.PROTOCOL_ERROR, "cluster name mismatch"}
	}

	return &proto.AcceptIdentification{
		NodeType: stor.node.MyInfo.Type,
		MyNID:    stor.node.MyInfo.NID, // XXX lock wrt update
		YourNID:  idReq.NID,
	}, nil
}

// serveLink serves an incoming node-node link connection.
func (stor *Storage) serveLink(ctx context.Context, req *neonet.Request, idReq *proto.RequestIdentification) (err error) {
	link := req.Link()
	defer task.Runningf(&ctx, "serve %s", idReq.NID)(&err)

	// first process identification
	// TODO -> .Accept() that would listen, handshake and identify a node
	//      -> and only then return it to serve loop if identified ok.
	idResp, err := stor.identify(ctx, idReq)
	err2 := req.Reply(idResp)
	if err == nil {
		err = err2
	}
	if err != nil {
		return err
	}

	// client passed identification, now serve other requests
	wg := xsync.NewWorkGroup(ctx)
	for {
		req, err := link.Recv1()
		if err != nil {
			return err
		}

		// FIXME this go + link.Recv1() in serveClient arrange for N(goroutine) ↑
		// with O(1/nreq) rate (i.e. N(goroutine, nreq) ~ ln(nreq)).
		//
		// TODO -> do what go-fuse does:
		// - serve request in the goroutine that received it (reduces latency)
		// - spawn another goroutine to continue accept loop
		// - limit number of such accept-loop goroutines by GOMAXPROC
		wg.Go(func(ctx context.Context) error {
			return stor.serveClient(ctx, req)
		})
	}

	err = wg.Wait()
	return err
}

// serveClient serves an incoming client request and then keeps serving further
// requests received on the same link.
func (stor *Storage) serveClient(ctx context.Context, req neonet.Request) error {
	link := req.Link()
	for {
		resp := stor.serveClient1(ctx, req.Msg)

		err := req.Reply(resp)
		req.Close()
		if err != nil {
			return err
		}

		// XXX hack -> TODO resp.Release()
		// XXX req.Msg release too?
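		//
		// AnswerObject.Data is a refcounted data buffer filled in by the
		// backend on Load; now that the reply has been transmitted, drop
		// our reference so the buffer can be freed or reused.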
		if resp, ok := resp.(*proto.AnswerObject); ok {
			resp.Data.Release()
		}

		// keep on going in the same goroutine to avoid goroutine creation overhead
		// TODO += timeout -> go away if inactive
		req, err = link.Recv1()
		if err != nil {
			return err
		}
	}
}

// serveClient1 prepares response for 1 request from client.
func (stor *Storage) serveClient1(ctx context.Context, req proto.Msg) (resp proto.Msg) {
	switch req := req.(type) {
	case *proto.GetObject:
		xid := zodb.Xid{Oid: req.Oid}
		if req.At != proto.INVALID_TID {
			xid.At = req.At
		} else {
			xid.At = xzodb.Before2At(req.Before)
		}

		obj, err := stor.back.Load(ctx, xid)
		if err != nil {
			// translate err to NEO protocol error codes
			return proto.ZODBErrEncode(err)
		}

		// compatibility with py side:
		// for loadSerial - check we have exact hit - else "nodata"
		if req.At != proto.INVALID_TID {
			if obj.Serial != req.At {
				return &proto.Error{
					Code:    proto.OID_NOT_FOUND,
					Message: fmt.Sprintf("%s: no data with serial %s", xid.Oid, req.At),
				}
			}
		}

		return obj

	case *proto.LastTransaction:
		lastTid, err := stor.back.LastTid(ctx)
		if err != nil {
			return proto.ZODBErrEncode(err)
		}

		return &proto.AnswerLastTransaction{lastTid}

	// TODO case *ObjectHistory:
	// TODO case *StoreObject:
	// ...

	default:
		return &proto.Error{proto.PROTOCOL_ERROR, fmt.Sprintf("unexpected message %T", req)}
	}
}
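
// A summary of the GetObject -> backend Load mapping implemented by serveClient1
// above (a sketch; Oid/TID values are symbolic):
//
//	GetObject{Oid, Before: B, At: INVALID_TID}  ->  Load(Xid{Oid, At: Before2At(B)})  // ZODB loadBefore
//	GetObject{Oid, At: T, Before: INVALID_TID}  ->  Load(Xid{Oid, At: T})             // ZODB loadSerial; serial must match exactly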