# -*- coding: utf-8 -*-
#
# Copyright (C) 2012-2017  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import random, weakref
from bisect import bisect
from collections import defaultdict
from neo.lib import logging
from neo.lib.bootstrap import BootstrapManager
from neo.lib.exception import PrimaryFailure
from neo.lib.handler import EventHandler
from neo.lib.node import NodeManager
from neo.lib.protocol import ClusterStates, \
    NodeTypes, Packets, uuid_str, ZERO_TID
from neo.lib.util import add64, dump
from .app import StateChangedException
from .pt import PartitionTable
from .handlers.backup import BackupHandler

"""
Backup algorithm

This implementation relies on normal storage replication.
Storage nodes that are specialised for backup are not in the same NEO cluster,
but are managed by another master in a different cluster.

When the cluster is in BACKINGUP state, its master acts like a client to the
master of the main cluster. It is notified of new data by invalidations and,
in turn, tells its storage nodes what to replicate and when.

Storages stay in UP_TO_DATE state, even if their partitions are synchronized
up to different TIDs. Storage nodes remember that they are in this state, and
when switching to RUNNING state, the cluster truncates the DB at the
"backup TID", which is the last TID for which we have all data. This TID can't
be deduced from the 'trans' and 'obj' tables, as is done in normal mode, so:
- The master must notify storages even of transactions that don't modify their
  partitions: see Replicate packets without any source.
- 'backup_tid' properties exist in many places, on the master and on the
  storages, so that the DB can be made consistent again at any moment, losing
  little or no data.

Among the backup storage nodes assigned to a partition, one is chosen as the
primary for that partition: only this node fetches data from the upstream
cluster, which minimizes bandwidth between the clusters. The other replicas
synchronize from the primary node.

There is no conflict of node ids between the 2 clusters:
- Storage nodes connect anonymously to the upstream cluster.
- The master node gets an id from the upstream master and uses it only when
  communicating with it.
"""

class BackupApplication(object):

    pt = None
    uuid = None

    def __init__(self, app, name, master_addresses):
        self.app = weakref.proxy(app)
        self.name = name
        self.nm = NodeManager()
        for master_address in master_addresses:
            self.nm.createMaster(address=master_address)

    em = property(lambda self: self.app.em)
    ssl = property(lambda self: self.app.ssl)

    def close(self):
        self.nm.close()
        del self.__dict__

    def log(self):
        self.nm.log()
        if self.pt is not None:
            self.pt.log()

    def provideService(self):
        logging.info('provide backup')
        poll = self.em.poll
        app = self.app
        pt = app.pt
        while True:
            app.changeClusterState(ClusterStates.STARTING_BACKUP)
            bootstrap = BootstrapManager(self, NodeTypes.CLIENT)
            # {offset -> node}   (primary storage for offset, which talks to the upstream cluster)
            self.primary_partition_dict = {}
            # [[tid]]   partition -> ascending [tid]  (transactions currently scheduled for sync)
            self.tid_list = tuple([] for _ in xrange(pt.getPartitions()))
            try:
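                # Wait until all nodes hosting readable cells are ready
                # before connecting to the upstream master.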
                while True:
                    for node in pt.getNodeSet(readable=True):
                        if not app.isStorageReady(node.getUUID()):
                            break
                    else:
                        break
                    poll(1)
                node, conn, num_partitions, num_replicas = \
                    bootstrap.getPrimaryConnection()
                try:
                    app.changeClusterState(ClusterStates.BACKINGUP)
                    del bootstrap, node
                    if num_partitions != pt.getPartitions():
                        raise RuntimeError("inconsistent number of partitions")
                    self.ignore_invalidations = True
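                    # Invalidations from upstream are ignored until the
                    # answers to the requests below give us a consistent
                    # starting point (the flag is handled by BackupHandler).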
                    self.pt = PartitionTable(num_partitions, num_replicas)
                    conn.setHandler(BackupHandler(self))
                    conn.ask(Packets.AskPartitionTable())
                    conn.ask(Packets.AskLastTransaction())
                    # debug variable to log how big 'tid_list' can be.
                    self.debug_tid_count = 0
                    while True:
                        poll(1)
                except PrimaryFailure, msg:
                    logging.error('upstream master is down: %s', msg)
                finally:
                    app.backup_tid = pt.getBackupTid()
                    try:
                        conn.close()
                    except PrimaryFailure:
                        pass
                    try:
                        del self.pt
                    except AttributeError:
                        pass
            except StateChangedException, e:
                if e.args[0] != ClusterStates.STOPPING_BACKUP:
                    raise
                app.changeClusterState(*e.args)
                tid = app.backup_tid
                # Wait for non-primary partitions to catch up,
                # so that all UP_TO_DATE cells are really UP_TO_DATE.
                # XXX: Another possibility could be to outdate such cells, and
                #      they would be quickly updated at the beginning of the
                #      RUNNING phase. This may simplify code.
                # Any unfinished replication from upstream will be truncated.
                while pt.getBackupTid(min) < tid:
                    poll(1)
                last_tid = app.getLastTransaction()
                handler = EventHandler(app)
                if tid < last_tid:
                    assert tid != ZERO_TID
                    logging.warning("Truncating at %s (last_tid was %s)",
                        dump(app.backup_tid), dump(last_tid))
                else:
                    # We will do a dummy truncation, just to leave backup mode,
                    # so it's fine to start automatically if there's any
                    # missing storage.
                    # XXX: Consider using another method to leave backup mode,
                    #      at least when there's nothing to truncate. Because
                    #      in case of StoppedOperation during VERIFYING state,
                    #      this flag will be wrongly set to False.
                    app._startup_allowed = True
                # If any error happened before reaching this line, we'd go back
                # to backup mode, which is the right mode to recover.
                del app.backup_tid
                # Now back to RECOVERY...
                return tid
            finally:
                del self.primary_partition_dict, self.tid_list
                pt.clearReplicating()

    def nodeLost(self, node):
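        """Handle the loss of a backup storage node.

        For each partition whose primary was the lost node, promote the most
        advanced remaining readable cell to primary, ask the lagging cells to
        replicate from it, and trigger replication from upstream on the new
        primaries.
        """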
        getCellList = self.app.pt.getCellList
        trigger_set = set()
        for offset, primary_node in self.primary_partition_dict.items():
            if primary_node is not node:
                continue
            cell_list = getCellList(offset, readable=True)
            cell = max(cell_list, key=lambda cell: cell.backup_tid)
            tid = cell.backup_tid
            self.primary_partition_dict[offset] = primary_node = cell.getNode()
            p = Packets.Replicate(tid, '', {offset: primary_node.getAddress()})
            for cell in cell_list:
                cell.replicating = tid
                if cell.backup_tid < tid:
                    logging.debug(
                        "ask %s to replicate partition %u up to %s from %s",
                        uuid_str(cell.getUUID()), offset,  dump(tid),
                        uuid_str(primary_node.getUUID()))
                    cell.getNode().send(p)
            trigger_set.add(primary_node)
        for node in trigger_set:
            self.triggerBackup(node)

    # NOTE called by BackupHandler.invalidateObjects() when it learns that
    # partitions in partition_set were updated in the upstream cluster (up to `tid`)
    def invalidatePartitions(self, tid, prev_tid, partition_set):
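        """Record that upstream transaction `tid` (following `prev_tid`)
        modified the partitions in `partition_set`.

        Touched partitions get `tid` scheduled for replication and a primary
        node assigned if they have none yet; untouched partitions simply get
        the 'backup_tid' of their up-to-date cells advanced, since there is
        nothing to replicate for them.
        """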
        app = self.app
        app.setLastTransaction(tid)
        pt = app.pt
        trigger_set = set()
        untouched_dict = defaultdict(dict)
        for offset in xrange(pt.getPartitions()):
            try:
                last_max_tid = self.tid_list[offset][-1]
            except IndexError:
                last_max_tid = prev_tid
            if offset in partition_set:
                primary_list = []
                node_list = []
                cell_list = pt.getCellList(offset, readable=True)
                for cell in cell_list:
                    node = cell.getNode()
                    assert node.isConnected(), node
                    if cell.backup_tid == prev_tid:
                        if prev_tid == tid:
                            # Connecting to upstream: any node that is
                            # up-to-date wrt upstream is a candidate for
                            # being primary.
                            assert self.ignore_invalidations
                            if app.isStorageReady(node.getUUID()):
                                primary_list.append(node)
                            continue
                        # Given 4 TIDs t0 < t1 < t2 < t3: if a cell is only
                        # modified by t0 & t3 and has all data for t0, 4 values
                        # are possible for its 'backup_tid' until it replicates
                        # up to t3: t0, t1, t2 or t3 - 1.
                        # Choosing the smallest one (t0) would be easier to
                        # implement, but when leaving backup mode, we would
                        # always lose data whenever the last full transaction
                        # does not modify all partitions. t1 is wrong for the
                        # same reason. So we choose the highest one (t3 - 1).
                        # t2 should also work but would probably be harder to
                        # implement.
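                        # (add64(tid, -1) computes that value: the TID
                        # immediately before `tid` in 8-byte TID arithmetic.)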
                        cell.backup_tid = add64(tid, -1)
                        logging.debug(
                            "partition %u: updating backup_tid of %r to %s",
                            offset, cell, dump(cell.backup_tid))
                    else:
                        assert cell.backup_tid < last_max_tid, (
                            cell.backup_tid, last_max_tid, prev_tid, tid)
                    if app.isStorageReady(node.getUUID()):
                        node_list.append(node)
                # Make sure we have a primary storage for this partition.
                if offset not in self.primary_partition_dict:
                    self.primary_partition_dict[offset] = \
                        random.choice(primary_list or node_list)
                if node_list:
                    self.tid_list[offset].append(tid)   # XXX check that tids are appended in increasing order
                    if primary_list:
                        # Resume replication to secondary cells.
                        self._triggerSecondary(
                            self.primary_partition_dict[offset],
                            offset, tid, cell_list)
                    else:
                        trigger_set.update(node_list)
            else:
                # Partition not touched, so increase 'backup_tid' of all
                # "up-to-date" replicas, without having to replicate.
                for cell in pt.getCellList(offset, readable=True):
                    if last_max_tid <= cell.backup_tid:
                        cell.backup_tid = tid
                        untouched_dict[cell.getNode()][offset] = None
                    elif last_max_tid <= cell.replicating:
                        # Same for 'replicating' to avoid useless orders.
                        logging.debug("silently update replicating order"
                            " of %s for partition %u, up to %s",
                            uuid_str(cell.getUUID()), offset,  dump(tid))
                        cell.replicating = tid
        for node, untouched_dict in untouched_dict.iteritems():
            if app.isStorageReady(node.getUUID()):
                node.send(Packets.Replicate(tid, '', untouched_dict))      # NOTE Mb -> Sb  (notify tid brings no new data)
        for node in trigger_set:
            self.triggerBackup(node)
        count = sum(map(len, self.tid_list))
        if self.debug_tid_count < count:
            logging.debug("Maximum number of tracked tids: %u", count)
            self.debug_tid_count = count

    def triggerBackup(self, node):
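        """Ask a backup storage node to replicate, from the upstream cluster,
        all partitions for which it is the primary and for which transactions
        are pending.
        """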
        tid_list = self.tid_list
        tid = self.app.getLastTransaction()
        replicate_list = []
        for offset, cell in self.app.pt.iterNodeCell(node):
            max_tid = tid_list[offset]
            if max_tid and self.primary_partition_dict[offset] is node and \
               max(cell.backup_tid, cell.replicating) < max_tid[-1]:
                cell.replicating = tid
                replicate_list.append(offset)
        if not replicate_list:
            return
        getCellList = self.pt.getCellList
        source_dict = {}
        address_set = set()
        for offset in replicate_list:
            cell_list = getCellList(offset, readable=True)
            random.shuffle(cell_list)
            assert cell_list, offset
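            # Prefer an upstream address that was already chosen for another
            # partition, so that the storage node may fetch several partitions
            # from the same upstream node; otherwise remember the new address.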
            for cell in cell_list:
                addr = cell.getAddress()
                if addr in address_set:
                    break
            else:
                address_set.add(addr)
            source_dict[offset] = addr
            logging.debug("ask %s to replicate partition %u up to %s from %r",
                uuid_str(node.getUUID()), offset,  dump(tid), addr)
        node.send(Packets.Replicate(tid, self.name, source_dict))         # NOTE Mb -> Sb  (notify to trigger replicate up to tid)

    # NOTE feedback from Sb -> Mb that a partition (whose replication was
    # requested by invalidatePartitions -> triggerBackup) has been replicated
    def notifyReplicationDone(self, node, offset, tid):
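        """Process the notification that `node` has replicated partition
        `offset` up to `tid`.

        Update the cell's 'backup_tid', forget tids that no cell needs
        anymore, and chain further replication orders if more transactions
        are pending for this partition.
        """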
        app = self.app
        cell = app.pt.getCell(offset, node.getUUID())
        tid_list = self.tid_list[offset]
        if tid_list: # may be empty if the cell is out-of-date      # XXX check how
                     # or if we're not fully initialized
            if tid < tid_list[0]:                                   # XXX why is this possible?
                cell.replicating = tid
            else:
                try:
                    tid = add64(tid_list[bisect(tid_list, tid)], -1)    # XXX why -1 ?
                except IndexError:
                    last_tid = app.getLastTransaction()
                    if tid < last_tid:
                        tid = last_tid
                        node.send(Packets.Replicate(tid, '', {offset: None}))
        logging.debug("partition %u: updating backup_tid of %r to %s",
                      offset, cell, dump(tid))
        cell.backup_tid = tid
        # TODO: Provide invalidation feedback about new txns to read-only
        #       clients connected to backup cluster. Not only here but also
        #       hooked to in-progress feedback from fetchObjects (storage).
        # Forget tids we won't need anymore.
        cell_list = app.pt.getCellList(offset, readable=True)
        del tid_list[:bisect(tid_list, min(x.backup_tid for x in cell_list))]   # XXX not only for primary Sb ?
        primary_node = self.primary_partition_dict.get(offset)
        primary = primary_node is node
        result = None if primary else app.pt.setUpToDate(node, offset)
        assert cell.isReadable()
        if result: # was out-of-date
            if primary_node is not None:
                max_tid, = [x.backup_tid for x in cell_list
                                         if x.getNode() is primary_node]
                if tid < max_tid:
                    cell.replicating = max_tid
                    logging.debug(
                        "ask %s to replicate partition %u up to %s from %s",
                        uuid_str(node.getUUID()), offset,  dump(max_tid),
                        uuid_str(primary_node.getUUID()))
                    node.send(Packets.Replicate(max_tid, '',
                        {offset: primary_node.getAddress()}))
        else:
            if app.getClusterState() == ClusterStates.BACKINGUP:
                self.triggerBackup(node)
            if primary:
                self._triggerSecondary(node, offset, tid, cell_list)
        return result

    def _triggerSecondary(self, node, offset, tid, cell_list):
        # Notify secondary storages that they can replicate from
        # primary ones, even if they are already replicating.
        p = Packets.Replicate(tid, '', {offset: node.getAddress()})
        for cell in cell_list:
            if max(cell.backup_tid, cell.replicating) < tid:
                cell.replicating = tid
                logging.debug(
                    "ask %s to replicate partition %u up to %s from %s",
                    uuid_str(cell.getUUID()), offset,
                    dump(tid), uuid_str(node.getUUID()))
                cell.getNode().send(p)