#! /usr/bin/env python3
"""
Convenience script for gathering metrics data

"""
import argparse
import logging
import os
import shutil
from datetime import datetime, timedelta

from indy_common.config_util import getConfig
from plenum.common.constants import KeyValueStorageType
from plenum.common.metrics_collector import MetricsName
from plenum.common.metrics_stats import load_metrics_from_kv_store
from storage.helper import initKeyValueStorage

# Silence logging for this script: ``getLogger()`` without a name returns the
# root logger, so the statements below drop every handler installed on the
# root logger and then switch it off.
logger = logging.getLogger()
logger.handlers = []
logger.disabled = True
logger.propagate = False


# Timestamp layouts accepted by --min_ts / --max_ts, most specific first.
_TIMESTAMP_FORMATS = (
    "%Y-%m-%d %H:%M:%S.%f",
    "%Y-%m-%dT%H:%M:%S.%f",
    "%Y-%m-%d %H:%M:%S",
    "%Y-%m-%dT%H:%M:%S",
    "%Y-%m-%d",
)


def _parse_timestamp(value):
    """Convert a command-line timestamp string into a naive ``datetime``.

    argparse passes the raw argument string to the ``type`` callable, and
    ``datetime`` cannot be constructed from a single string, so the
    timestamp options need an explicit parser.

    :param value: text such as ``2018-07-01``, ``2018-07-01 12:30:00`` or
        ``2018-07-01T12:30:00.250``
    :return: the parsed ``datetime`` (no timezone attached)
    :raises argparse.ArgumentTypeError: if no supported layout matches

    NOTE(review): the value is presumably compared with the timestamps
    stored alongside the metrics, so it should be given in the same
    timezone as those - confirm against the metrics collector.
    """
    for fmt in _TIMESTAMP_FORMATS:
        try:
            return datetime.strptime(value, fmt)
        except ValueError:
            continue
    raise argparse.ArgumentTypeError(
        "invalid timestamp {!r}, expected e.g. 2018-07-01 or '2018-07-01 12:30:00'".format(value))


def read_args(argv=None):
    """Parse the command-line options of the metrics script.

    :param argv: optional list of argument strings; when ``None`` (the
        default) argparse reads ``sys.argv[1:]`` as before
    :return: ``argparse.Namespace`` with ``data_dir``, ``node_name``,
        ``network``, ``output``, ``min_ts``, ``max_ts`` and ``step``
    """
    parser = argparse.ArgumentParser(description="Process metrics")

    parser.add_argument(
        '--data_dir',
        required=False,
        help="Path to metrics data (RocksDB); when given, --node_name and --network are ignored")
    parser.add_argument('--node_name', required=False, help="Node's name")
    parser.add_argument('--network', required=False, help="Network name to read metrics from")
    parser.add_argument('--output', required=False, default="metrics.csv", help="Output CSV file name with metrics")

    parser.add_argument(
        '--min_ts',
        type=_parse_timestamp,
        required=False,
        default=None,
        help="Process all metrics starting from given timestamp (beginning by default)")
    parser.add_argument(
        '--max_ts',
        type=_parse_timestamp,
        required=False,
        default=None,
        help="Process all metrics till given timestamp (end by default)")
    parser.add_argument(
        '--step',
        type=int,
        required=False,
        default=60,
        help="Build timelog with given timestep, seconds (60 by default)"
    )

    return parser.parse_args(argv)


def get_data_dir(node_name=None, network=None):
    """Locate the data directory of a node inside the configured ledger dir.

    The path is ``<config.LEDGER_DIR>/<network>/data/<node_name>``.

    :param node_name: name of the node's sub-directory; when empty, the
        alphabetically first entry of the network's ``data`` directory is
        used (sorted, so the choice does not depend on listing order)
    :param network: network name; falls back to ``config.NETWORK_NAME``
    :return: path to the node's data directory
    :raises SystemExit: with status 1, after printing an explanation, when
        the network's ``data`` directory is missing or empty
    """
    config = getConfig()
    network_name = network if network else config.NETWORK_NAME
    data_dir = os.path.join(config.LEDGER_DIR, network_name, 'data')

    if not os.path.exists(data_dir):
        print("No such file or directory: {}".format(data_dir))
        print("Please check, that network: '{}' was used ".format(network_name))
        # Non-zero status so that callers of the script can detect the failure.
        raise SystemExit(1)

    if not node_name:
        candidates = sorted(os.listdir(data_dir))
        if not candidates:
            print("Node's 'data' folder not found: {}".format(data_dir))
            raise SystemExit(1)
        node_name = candidates[0]

    return os.path.join(data_dir, node_name)


def make_copy_of_data(data_dir):
    """Create a fresh sibling copy of ``data_dir`` and return its path.

    The copy lives next to the original under the same name with a
    ``-read-copy`` suffix. Anything already present at that path is removed
    first, so the result always mirrors the current contents of ``data_dir``.

    :param data_dir: directory to duplicate
    :return: path of the newly created copy
    """
    snapshot_dir = data_dir + '-read-copy'

    # A leftover copy from an earlier run would make copytree fail.
    stale_copy_present = os.path.exists(snapshot_dir)
    if stale_copy_present:
        shutil.rmtree(snapshot_dir)

    shutil.copytree(data_dir, snapshot_dir)
    return snapshot_dir


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is zero or empty."""
    return numerator / denominator if denominator else float('nan')


def _write_timelog(stats, output):
    """Write one CSV row per time frame of ``stats`` into the file ``output``."""
    columns = [
        "timestamp",
        "avg_looper_run_time_spent",
        "avg_three_pc_batch_size",
        "avg_incoming_node_message_size",
        "avg_outgoing_node_message_size",
        "avg_incoming_client_message_size",
        "avg_outgoing_client_message_size",
        "node_stack_messages_processed_per_sec",
        "client_stack_messages_processed_per_sec",
        "three_pc_batch_count_per_sec",
        "requests_count_per_sec",
        "incoming_node_messages_per_sec",
        "incoming_node_traffic_per_sec",
        "outgoing_node_messages_per_sec",
        "outgoing_node_traffic_per_sec",
        "incoming_client_messages_per_sec",
        "incoming_client_traffic_per_sec",
        "outgoing_client_messages_per_sec",
        "outgoing_client_traffic_per_sec"
    ]
    # The frame length is the same for every row, so compute it once.
    step_seconds = stats.timestep.total_seconds()

    with open(output, 'w') as f:
        f.write(",".join(columns))
        f.write("\n")
        for ts, frame in sorted(stats.frames(), key=lambda v: v[0]):
            looper = frame.get(MetricsName.LOOPER_RUN_TIME_SPENT)
            three_pc = frame.get(MetricsName.THREE_PC_BATCH_SIZE)
            node_in = frame.get(MetricsName.INCOMING_NODE_MESSAGE_SIZE)
            node_out = frame.get(MetricsName.OUTGOING_NODE_MESSAGE_SIZE)
            client_in = frame.get(MetricsName.INCOMING_CLIENT_MESSAGE_SIZE)
            client_out = frame.get(MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE)
            node_stack = frame.get(MetricsName.NODE_STACK_MESSAGES_PROCESSED)
            client_stack = frame.get(MetricsName.CLIENT_STACK_MESSAGES_PROCESSED)

            # Order must match ``columns`` above.
            row = [
                ts.replace(microsecond=0),
                looper.avg,
                three_pc.avg,
                node_in.avg,
                node_out.avg,
                client_in.avg,
                client_out.avg,
                node_stack.sum / step_seconds,
                client_stack.sum / step_seconds,
                three_pc.count / step_seconds,
                three_pc.sum / step_seconds,
                node_in.count / step_seconds,
                node_in.sum / step_seconds,
                node_out.count / step_seconds,
                node_out.sum / step_seconds,
                client_in.count / step_seconds,
                client_in.sum / step_seconds,
                client_out.count / step_seconds,
                client_out.sum / step_seconds
            ]
            # Falsy cells (None, 0, 0.0) are all written as a plain "0".
            f.write(",".join(str(v if v else "0") for v in row))
            f.write("\n")


def _print_summary(stats):
    """Print the covered time range and the accumulated value of each reported metric."""
    print("Start time: {}".format(stats.min_ts))
    print("End time: {}".format(stats.max_ts))
    print("Duration: {}".format(stats.max_ts - stats.min_ts))
    print("")

    total = stats.total
    sections = [
        ("Number of node messages processed in one looper run:", MetricsName.NODE_STACK_MESSAGES_PROCESSED),
        ("Number of client messages processed in one looper run:", MetricsName.CLIENT_STACK_MESSAGES_PROCESSED),
        ("Seconds passed between looper runs:", MetricsName.LOOPER_RUN_TIME_SPENT),
        ("Number of requests in one 3PC batch:", MetricsName.THREE_PC_BATCH_SIZE),
        ("Number of messages in one transport batch:", MetricsName.TRANSPORT_BATCH_SIZE),
        ("Incoming node message size, bytes:", MetricsName.INCOMING_NODE_MESSAGE_SIZE),
        ("Outgoing node message size, bytes:", MetricsName.OUTGOING_NODE_MESSAGE_SIZE),
        ("Incoming client message size, bytes:", MetricsName.INCOMING_CLIENT_MESSAGE_SIZE),
        ("Outgoing client message size, bytes:", MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE),
    ]
    for title, metric in sections:
        print(title)
        print("   {}".format(total.get(metric)))
        print("")


def _print_additional_statistics(total):
    """Print ratios derived from the accumulated message and batch metrics.

    A ratio whose denominator is zero (e.g. no client traffic or no 3PC
    batches in the selected range) is printed as ``nan`` instead of
    aborting the script with ``ZeroDivisionError``.
    """
    print("Additional statistics:")
    three_pc = total.get(MetricsName.THREE_PC_BATCH_SIZE)
    node_in = total.get(MetricsName.INCOMING_NODE_MESSAGE_SIZE)
    node_out = total.get(MetricsName.OUTGOING_NODE_MESSAGE_SIZE)
    client_in = total.get(MetricsName.INCOMING_CLIENT_MESSAGE_SIZE)
    client_out = total.get(MetricsName.OUTGOING_CLIENT_MESSAGE_SIZE)
    node_messages = node_in.count + node_out.count
    node_traffic = node_in.sum + node_out.sum
    client_messages = client_in.count + client_out.count
    client_traffic = client_in.sum + client_out.sum
    print("   Node incoming/outgoing: {:.2f} messages, {:.2f} traffic"
          .format(_safe_ratio(node_in.count, node_out.count),
                  _safe_ratio(node_in.sum, node_out.sum)))
    print("   Client incoming/outgoing: {:.2f} messages, {:.2f} traffic"
          .format(_safe_ratio(client_in.count, client_out.count),
                  _safe_ratio(client_in.sum, client_out.sum)))
    print("   Node/client: {:.2f} messages, {:.2f} traffic"
          .format(_safe_ratio(node_messages, client_messages),
                  _safe_ratio(node_traffic, client_traffic)))
    print("   Node messages per batch: {:.2f}".format(_safe_ratio(node_messages, three_pc.count)))
    print("   Node traffic per batch: {:.2f}".format(_safe_ratio(node_traffic, three_pc.count)))
    print("   Node messages per request: {:.2f}".format(_safe_ratio(node_messages, three_pc.sum)))
    print("   Node traffic per request: {:.2f}".format(_safe_ratio(node_traffic, three_pc.sum)))


def process_storage(storage, args):
    """Load metrics from ``storage``, dump a CSV timelog and print a summary.

    :param storage: key-value storage handed to ``load_metrics_from_kv_store``
    :param args: parsed command-line options; ``min_ts``, ``max_ts``,
        ``step`` (seconds per frame) and ``output`` (CSV path) are read

    The values returned by ``frame.get(...)`` / ``total.get(...)`` are used
    through their ``avg``, ``sum`` and ``count`` attributes.
    """
    stats = load_metrics_from_kv_store(storage, args.min_ts, args.max_ts, timedelta(seconds=args.step))

    _write_timelog(stats, args.output)
    _print_summary(stats)
    _print_additional_statistics(stats.total)


if __name__ == '__main__':
    cli_args = read_args()

    # An explicit --data_dir is opened directly as RocksDB with an empty
    # database name; the node config is not consulted on this path.
    if cli_args.data_dir is not None:
        explicit_storage = initKeyValueStorage(KeyValueStorageType.Rocksdb, cli_args.data_dir, "")
        process_storage(explicit_storage, cli_args)
        exit()

    metrics_dir = get_data_dir(cli_args.node_name, cli_args.network)
    temp_copy = None  # path of the throw-away copy, if one gets created
    try:
        node_config = getConfig()
        storage_type = node_config.METRICS_KV_STORAGE

        # Any storage type other than RocksDB is read from a temporary copy.
        # NOTE(review): presumably because such a store cannot be opened while
        # the node itself holds it - confirm.
        if storage_type != KeyValueStorageType.Rocksdb:
            temp_copy = make_copy_of_data(metrics_dir)
            metrics_dir = temp_copy

        kv_store = initKeyValueStorage(
            storage_type,
            metrics_dir,
            node_config.METRICS_KV_DB_NAME)

        process_storage(kv_store, cli_args)
    finally:
        # Remove the temporary copy even when processing failed.
        if temp_copy is not None:
            shutil.rmtree(temp_copy)
