#!/usr/bin/env python3
#
# Copied with minimal modifications from curio
# https://github.com/dabeaz/curio


import argparse
from concurrent import futures
import socket
import time

import numpy as np


def positive_int(text):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(
            'expected a positive integer, got {}'.format(text))
    return value


def parse_addr(addr):
    """Parse the ``--addr`` option.

    Returns ``(unix, address)``:

    * ``file:PATH``  -> ``(True, PATH)`` (Unix-domain socket path);
    * ``HOST:PORT``  -> ``(False, (HOST, PORT))``.  The string is split on
      the *last* colon and surrounding brackets are stripped from HOST, so
      IPv6 literals such as ``[::1]:25000`` are accepted.

    Raises ValueError if the port is missing or not an integer.
    """
    if addr.startswith('file:'):
        return True, addr[len('file:'):]
    host, sep, port = addr.rpartition(':')
    if not sep:
        raise ValueError(
            'address must be HOST:PORT or file:PATH, got {!r}'.format(addr))
    return False, (host.strip('[]'), int(port))


def connect(addr, unix, sock_timeout):
    """Open a blocking stream socket to *addr* with a per-operation timeout.

    *addr* is a filesystem path when *unix* is true, otherwise a
    ``(host, port)`` tuple handed to socket.create_connection, which
    resolves the host via getaddrinfo (IPv4 or IPv6).
    """
    if not unix:
        return socket.create_connection(addr, timeout=sock_timeout)
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.settimeout(sock_timeout)
        sock.connect(addr)
    except BaseException:
        # Don't leak the descriptor if the connect fails.
        sock.close()
        raise
    return sock


def latency_bucket(elapsed, nbuckets):
    """Return the 1 ms histogram bucket index for *elapsed* seconds.

    Latencies at or beyond the histogram range are clamped into the last
    bucket, so a slow request is still counted instead of raising
    IndexError.
    """
    return min(round(elapsed * 1000), nbuckets - 1)


def mean_latency(latency_stats):
    """Return the mean latency in ms of a 1 ms-bucket histogram.

    Bucket ``i`` holds the number of requests that took ~``i`` ms.
    Returns NaN for an empty histogram.
    """
    total = latency_stats.sum()
    if total == 0:
        return float('nan')
    weighted = np.dot(latency_stats, np.arange(len(latency_stats)))
    return float(weighted / total)


def latency_percentiles(latency_stats, percentiles):
    """Return nearest-rank percentiles (in ms) of a 1 ms-bucket histogram.

    For each ``p`` in *percentiles* the result is the smallest bucket index
    whose cumulative count reaches ``ceil(p/100 * total)`` (at least 1)
    requests.

    Raises ValueError if the histogram is empty.
    """
    total = latency_stats.sum()
    if total == 0:
        raise ValueError('latency histogram is empty')
    cumulative = np.cumsum(latency_stats)
    ranks = np.maximum(
        np.ceil(total * np.asarray(percentiles, dtype=float) / 100.0), 1)
    return [int(i) for i in np.searchsorted(cumulative, ranks, side='left')]


def run_test(start, duration, addr, unix, msg, nbuckets, sock_timeout):
    """Drive one echo connection until *duration* seconds after *start*.

    Each iteration sends *msg*, reads until the same number of bytes has
    come back, and records the round-trip time in a histogram of
    *nbuckets* 1 ms buckets (see latency_bucket).  Defined at module level
    so ProcessPoolExecutor workers can unpickle it under any start method.

    Returns ``(requests_completed, histogram)``.
    Raises ConnectionError if the server closes the connection before the
    full reply arrives; a socket timeout error propagates if any single
    send/recv exceeds *sock_timeout* seconds.
    """
    msgsize = len(msg)
    latency_stats = np.zeros(nbuckets, dtype=np.int64)
    n = 0
    with connect(addr, unix, sock_timeout) as sock:
        # NOTE: *start* comes from the parent's time.monotonic(); this
        # relies on the monotonic clock being shared across processes on
        # the host (as the original script also assumed).
        while time.monotonic() - start < duration:
            req_start = time.monotonic()
            sock.sendall(msg)
            nrecv = 0
            while nrecv < msgsize:
                resp = sock.recv(msgsize)
                if not resp:
                    raise ConnectionError(
                        'server closed the connection mid-reply')
                nrecv += len(resp)
            elapsed = time.monotonic() - req_start
            latency_stats[latency_bucket(elapsed, nbuckets)] += 1
            n += 1
    return n, latency_stats


def main():
    """Parse the command line, run the benchmark and print a summary."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--msize', default=1000, type=int,
                        help='message size in bytes')
    parser.add_argument('--duration', '-T', default=30, type=int,
                        help='duration of test in seconds')
    parser.add_argument('--times', default=1, type=positive_int,
                        help='number of times to run the test')
    parser.add_argument('--concurrency', default=3, type=positive_int,
                        help='request concurrency')
    parser.add_argument('--timeout', default=2, type=positive_int,
                        help='socket timeout in seconds')
    parser.add_argument('--addr', default='127.0.0.1:25000', type=str,
                        help='server address')
    args = parser.parse_args()

    try:
        unix, addr = parse_addr(args.addr)
    except ValueError as exc:
        parser.error(str(exc))

    msgsize = args.msize
    msg = b'x' * msgsize
    # One 1 ms bucket per millisecond of the socket timeout; anything slower
    # is clamped into the last bucket by latency_bucket().
    nbuckets = args.timeout * 1000

    messages = 0
    latency_stats = np.zeros(nbuckets, dtype=np.int64)
    start = time.monotonic()
    for _ in range(args.times):
        with futures.ProcessPoolExecutor(max_workers=args.concurrency) as e:
            # Each run gets its own deadline; reusing the overall start
            # would make every run after the first finish immediately.
            run_start = time.monotonic()
            fs = [e.submit(run_test, run_start, args.duration, addr, unix,
                           msg, nbuckets, args.timeout)
                  for _ in range(args.concurrency)]
            for fut in fs:
                t_messages, t_latency_stats = fut.result()
                messages += t_messages
                latency_stats += t_latency_stats
    duration = time.monotonic() - start

    print(messages, 'in', round(duration, 2))
    if messages:
        print('Latency avg: {}ms'.format(
            round(mean_latency(latency_stats), 2)))
        percentiles = [50, 75, 90, 99]
        percentile_data = [
            '{}%: {}ms'.format(p, value)
            for p, value in zip(
                percentiles, latency_percentiles(latency_stats, percentiles))
        ]
        print('Latency distribution: {}'.format('; '.join(percentile_data)))
    print('Requests/sec: {}'.format(round(messages / duration, 2)))
    transfer = (messages * msgsize / (1024 * 1024)) / duration
    print('Transfer/sec: {}MiB'.format(round(transfer, 2)))


if __name__ == '__main__':
    main()