#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import asyncio
import concurrent.futures
import curses
import datetime
import json
import os
import sys
import time
import subprocess
from collections import OrderedDict
from glob import glob

from sovrin_common.config_util import getConfig
from stp_core.common.log import Logger
from stp_core.common.log import getlogger

# Configuration object obtained from sovrin_common; main() reads ``baseDir``
# from it and sets ``enableStdOutLogging`` on it.
config = getConfig()

# Module-wide logger. It stays None until main() assigns getlogger() to it,
# so the classes and functions below can only log after main() has run.
logger = None  # to make flake8 happy
# Connections registered by accept_client() and removed when their task ends.
clients = {}   # task -> (reader, writer)


class BaseStats(OrderedDict):
    """Ordered mapping populated from a raw stats dictionary.

    Subclasses declare ``shema`` (the attribute name is spelled this way
    throughout the file): a list of ``(key, converter)`` pairs. For every
    pair the raw value is looked up with ``stats.get(key)`` and passed
    through ``converter``; the result is stored under ``key`` in schema
    order. A converter may be a builtin such as ``int`` or another
    ``BaseStats`` subclass for nested sections.
    """

    shema = []

    def __init__(self, stats, verbose=False):
        """Fill the mapping from *stats* according to ``shema``.

        :param stats: raw mapping providing ``.get(key)``
        :param verbose: remembered as ``self._verbose`` for subclasses
            that render themselves as text
        """
        for key, convert in self.shema:
            # TODO improve defaults processing
            try:
                value = convert(stats.get(key))
            except Exception as err:
                # Any lookup/conversion failure is logged and the
                # attribute falls back to None.
                logger.warning(
                    "{} Failed to parse attribute '{}': {}".format(
                        type(self).__name__, key, err))
                value = None
            self[key] = value

        self._verbose = verbose


class ConnectionStats(BaseStats):
    """Schema for one binding entry: ip (str), port (int), protocol (str)."""
    shema = [
        ("ip", str),
        ("port", int),
        ("protocol", str)
    ]


class BindingsStats(BaseStats):
    """Schema for the "bindings" section: a ConnectionStats each for the
    "node" and "client" entries."""
    shema = [
        ("node", ConnectionStats),
        ("client", ConnectionStats)
    ]


class TransactionsStats(BaseStats):
    """Schema for transaction counters: integer values under the "config",
    "ledger" and "pool" keys."""
    shema = [
        ("config", int),
        ("ledger", int),
        ("pool", int)
    ]


class AverageStats(BaseStats):
    """Schema for the "average-per-second" section: integer values for
    "read-transactions" and "write-transactions"."""
    shema = [
        ("read-transactions", int),
        ("write-transactions", int)
    ]


class MetricsStats(BaseStats):
    """Schema for the "metrics" section: integer "uptime" plus the nested
    "transaction-count" and "average-per-second" sections."""
    shema = [
        ("uptime", int),
        ("transaction-count", TransactionsStats),
        ("average-per-second", AverageStats)
    ]


class NodesListStats(BaseStats):
    """Schema for a group of nodes: integer "count" and a "list" of
    entries (ValidatorStats prints each entry as an alias)."""
    shema = [
        ("count", int),
        ("list", list)
    ]


class PoolStats(BaseStats):
    """Schema for the "pool" section: integer "total-count" plus
    NodesListStats for the "reachable" and "unreachable" groups."""
    shema = [
        ("total-count", int),
        ("reachable", NodesListStats),
        ("unreachable", NodesListStats)
    ]


class SoftwareStats(BaseStats):
    """Schema for the "software" section: string values under the
    "indy-node" and "sovrin" keys."""
    shema = [
        ("indy-node", str),
        ("sovrin", str)
    ]


class ValidatorStats(BaseStats):
    """Top-level validator report built from a raw stats dictionary.

    ``str()`` renders a multi-line, human-readable summary. Extra detail
    lines are included only when the instance was created with
    ``verbose=True``.
    """

    shema = [
        ("response-version", str),
        ("timestamp", int),
        ("alias", str),
        ("state", str),
        ("enabled", str),
        ("did", str),
        ("verkey", str),
        ("bindings", BindingsStats),
        ("metrics", MetricsStats),
        ("pool", PoolStats),
        ("software", SoftwareStats)
    ]

    def __str__(self):
        """Return the report as newline-separated text.

        Line templates that start with '#' are verbose-only: they are
        dropped in normal mode and printed without the marker in verbose
        mode. A timestamp, uptime or host list that is None (the value
        BaseStats stores when parsing fails) is rendered as 'unknown' or
        as an empty list instead of raising TypeError.
        """
        # TODO moving parts to other classes seems reasonable but
        # will drop visibility of output
        verbose = self._verbose
        bindings = self['bindings']
        metrics = self['metrics']
        tx_count = metrics['transaction-count']
        average = metrics['average-per-second']
        pool = self['pool']
        software = self['software']

        def format_port(binding):
            # Protocol and ip are shown in verbose mode only.
            if verbose:
                return "{}/{} on {}".format(
                    binding['port'], binding['protocol'], binding['ip'])
            return "{}".format(binding['port'])

        def format_time(timestamp):
            if timestamp is None:
                return 'unknown'
            # The strftime flag that drops zero padding is platform
            # specific: '#' on Windows, '-' elsewhere.
            flag = '#' if os.name == 'nt' else '-'
            return datetime.datetime.fromtimestamp(timestamp).strftime(
                "%A, %B %{0}d, %Y %{0}I:%M:%S %p".format(flag))

        def format_uptime(secs):
            if secs is None:
                return 'unknown'
            days, remainder = divmod(secs, 86400)
            hours, remainder = divmod(remainder, 3600)
            minutes, seconds = divmod(remainder, 60)
            parts = []

            # Skip leading zero units, but keep every unit after the first
            # non-zero one (e.g. "1 hour, 0 minutes, 5 seconds").
            for unit, value in zip(['day', 'hour', 'minute', 'second'],
                                   [days, hours, minutes, seconds]):
                if value or parts:
                    parts.append("{} {}{}".format(
                        value, unit, '' if value == 1 else 's'))

            return ", ".join(parts) if parts else '0 seconds'

        def host_lines(group):
            # ``list`` is None when the attribute could not be parsed.
            return ["#  {}".format(alias) for alias in group['list'] or []]

        lines = [
            "Validator {} is {}".format(self['alias'], self['state']),
            "#Current time:     {}".format(format_time(self['timestamp'])),
            "Validator DID:    {}".format(self['did']),
            "Verification Key: {}".format(self['verkey']),
            "Node Port:        {}".format(format_port(bindings['node'])),
            "Client Port:      {}".format(format_port(bindings['client'])),
            "Metrics:",
            "  Uptime: {}".format(format_uptime(metrics['uptime'])),
            "#  Total Config Transactions:  {}".format(tx_count['config']),
            "  Total Ledger Transactions:  {}".format(tx_count['ledger']),
            "  Total Pool Transactions:    {}".format(tx_count['pool']),
            "  Read Transactions/Seconds:  {}".format(
                average['read-transactions']),
            "  Write Transactions/Seconds: {}".format(
                average['write-transactions']),
            "Reachable Hosts:   {}/{}".format(
                pool['reachable']['count'], pool['total-count']),
        ]
        lines.extend(host_lines(pool['reachable']))
        lines.append("Unreachable Hosts: {}/{}".format(
            pool['unreachable']['count'], pool['total-count']))
        lines.extend(host_lines(pool['unreachable']))
        # The templates need '{}' placeholders; without them the literal
        # text is printed and the actual versions are silently ignored.
        lines.extend([
            "#Software Versions:",
            "#  indy-node: {}".format(software['indy-node']),
            "#  sovrin:    {}".format(software['sovrin']),
        ])

        # Drop lines that start with '#' unless verbose; otherwise strip
        # the '#' marker.
        shown = []
        for line in lines:
            if not line.startswith('#'):
                shown.append(line)
            elif verbose:
                shown.append(line[1:])
        return "\n".join(shown)


async def handle_client(client_reader, client_writer):
    """Read newline-delimited JSON stats from a client and pretty-print them.

    Each line is decoded, parsed as JSON and printed to stdout with a
    two-space indent. The loop ends on EOF or when reading fails. No read
    timeout is applied.

    :param client_reader: stream reader; ``readline()`` is awaited on it
    :param client_writer: stream writer (not used here; the caller closes
        it when the task finishes)
    :raises asyncio.CancelledError: re-raised after logging when the task
        is cancelled
    """
    while True:
        try:
            data = await client_reader.readline()
        except asyncio.CancelledError:
            # Since Python 3.8 asyncio.CancelledError is not a subclass of
            # concurrent.futures.CancelledError, so it must be caught by
            # its asyncio name. It is re-raised so the task ends cancelled.
            logger.warning("task has been cancelled")
            raise
        except Exception:
            logger.exception("failed to readline")
            return

        if data is None:
            logger.warning("Expected data, received None")
            return
        if not data:
            logger.warning("EOF received, closing connection")
            return

        logger.debug("Received data: {}".format(data))
        try:
            stats = json.loads(data.decode())
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError both derive from
            # ValueError; one bad line should not kill the connection.
            logger.warning(
                "Ignoring malformed stats line: {!r}".format(data))
            continue
        print(json.dumps(stats, indent=2))


def accept_client(client_reader, client_writer):
    """Start a handle_client() task for a new connection and track it.

    The task is recorded in the module-level ``clients`` mapping together
    with its reader/writer pair. When the task finishes, its entry is
    removed and the writer is closed.
    """
    logger.info("New Connection")
    handler = asyncio.Task(handle_client(client_reader, client_writer))
    clients[handler] = (client_reader, client_writer)

    def on_finished(finished):
        # Invoked with the finished task: forget it and close the stream.
        del clients[finished]
        client_writer.close()
        logger.info("End Connection")

    handler.add_done_callback(on_finished)


def get_stats_from_file(fpath, verbose, _json):
    """Load a stats JSON file and return it as a validator report.

    The "state" and "enabled" entries are replaced with the values
    reported by get_process_state() and get_enabled_state().

    :param fpath: path of the JSON stats file
    :param verbose: passed to ValidatorStats to control text rendering
    :param _json: when true, return a JSON string (2-space indent) instead
        of the ValidatorStats object
    """
    with open(fpath) as stats_file:
        raw_stats = json.load(stats_file)

    logger.debug("Data {}".format(raw_stats))

    report = ValidatorStats(raw_stats, verbose)
    report['state'] = get_process_state()
    report['enabled'] = get_enabled_state()

    if _json:
        return json.dumps(report, indent=2)
    return report


def get_process_state():
    """Return the sovrin-node service state as reported by systemctl.

    Runs ``systemctl is-failed sovrin-node`` through the shell with stderr
    merged into stdout. The trailing ``exit 0`` makes the shell exit
    successfully so check_output() does not raise on a non-zero systemctl
    status. 'inactive' is mapped to 'stopped' and 'active' to 'running';
    any other output is returned unchanged.
    """
    output = subprocess.check_output(
        'systemctl is-failed sovrin-node; exit 0',
        stderr=subprocess.STDOUT, shell=True)
    status = output.decode().strip()
    renamed = {'inactive': 'stopped', 'active': 'running'}
    return renamed.get(status, status)


def get_enabled_state():
    """Return whether the sovrin-node service is enabled.

    Runs ``systemctl is-enabled sovrin-node`` through the shell with
    stderr merged into stdout. The trailing ``exit 0`` keeps
    check_output() from raising on a non-zero systemctl status.

    :return: True for 'enabled' or 'static', False for 'disabled',
        otherwise the stripped command output itself (a string)
    """
    output = subprocess.check_output(
        'systemctl is-enabled sovrin-node; exit 0',
        stderr=subprocess.STDOUT, shell=True)
    status = output.decode().strip()

    if status in ('enabled', 'static'):
        return True
    if status == 'disabled':
        return False
    return status


def watch(fpath, verbose, _json, interval=3):
    """Continuously redraw the stats from *fpath* in a curses screen.

    The report is re-read and redrawn every *interval* seconds until the
    user presses Ctrl+C.

    :param fpath: path of the JSON stats file
    :param verbose: passed through to get_stats_from_file()
    :param _json: passed through to get_stats_from_file()
    :param interval: seconds to wait between redraws (default 3)
    """

    def _watch(stdscr):
        while True:
            text = str(get_stats_from_file(fpath, verbose, _json))

            stdscr.clear()
            try:
                stdscr.addstr(0, 0, text)
            except curses.error:
                # addstr() raises when the text does not fit the window;
                # keep watching with whatever could be drawn rather than
                # aborting.
                pass
            stdscr.refresh()

            # Sleep on every iteration so the loop does not spin at full
            # CPU speed.
            time.sleep(interval)

    try:
        curses.wrapper(_watch)
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to leave watch mode.
        pass

def main():
    """Command-line entry point.

    Parses the command-line options, enables file logging (and stdout
    logging with --stdlog), then prints the stats of every
    ``*_info.json`` file found in --basedir.
    """
    global logger

    def check_unsigned(s):
        # argparse "type" callable accepting only integers > 0; currently
        # referenced only by the disabled --port option below.
        try:
            value = int(s)
        except ValueError:
            value = None
        if value is not None and value > 0:
            return value
        raise argparse.ArgumentTypeError(("{!r} is incorrect, "
                                          "should be int > 0").format(s,))

    description = (
        "Tool to explore and gather statistics about running validator"
    )
    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument("-v", "--verbose", action="store_true",
                        help="Verbose mode (command line)")
    parser.add_argument("--json", action="store_true",
                        help="Format output as JSON (ignores -v)")

    statfile_group = parser.add_argument_group(
        "statfile", "settings for exploring validator stats from stat file"
    )
    statfile_group.add_argument("--basedir", metavar="PATH",
                                default=config.baseDir,
                                help=("Path to stats files"))
    # statfile_group.add_argument(
    #     "--watch", action="store_true", help="Watch for stats file updates"
    # )

    # The socket group is disabled for now because the feature is
    # unsupported.
    # socket_group = parser.add_argument_group(
    #     "socket", "settings for exploring validator stats from socket"
    # )
    #
    # socket_group.add_argument(
    #     "--listen", action="store_true",
    #     help="Listen socket for stats (ignores --statfile)"
    # )
    #
    # socket_group.add_argument(
    #     "-i", "--ip", metavar="IP", default=config.STATS_SERVER_IP,
    #     help="Server IP"
    # )
    # socket_group.add_argument(
    #     "-p", "--port", metavar="PORT", default=config.STATS_SERVER_PORT,
    #     type=check_unsigned, help="Server port"
    # )

    other_group = parser.add_argument_group("other", "other settings")
    other_group.add_argument("--stdlog", action="store_true",
                             help="Enable logging to stdout")

    args = parser.parse_args()

    # Logging: the log file is named after the script and placed in the
    # configured base directory.
    config.enableStdOutLogging = args.stdlog
    log_name = os.path.basename(sys.argv[0] + ".log")
    log_path = os.path.join(config.baseDir, log_name)

    logger = getlogger()
    Logger().enableFileLogging(log_path)

    logger.debug("Cmd line arguments: {}".format(args))

    # Listening on a socket is not supported for now.
    # if args.listen:
    #     logger.info("Starting server on {}:{} ...".format(args.ip, args.port))
    #     print("Starting server on {}:{} ...".format(args.ip, args.port))
    #
    #     loop = asyncio.get_event_loop()
    #     coro = asyncio.start_server(accept_client,
    #                                 args.ip, args.port, loop=loop)
    #     server = loop.run_until_complete(coro)
    #
    #     logger.info("Serving on {}:{} ...".format(args.ip, args.port))
    #     print('Serving on {} ...'.format(server.sockets[0].getsockname()))
    #
    #     # Serve requests until Ctrl+C is pressed
    #     try:
    #         loop.run_forever()
    #     except KeyboardInterrupt:
    #         pass
    #
    #     logger.info("Stopping server ...")
    #     print("Stopping server ...")
    #
    #     # Close the server
    #     server.close()
    #     for task in clients.keys():
    #         task.cancel()
    #     loop.run_until_complete(server.wait_closed())
    #     loop.close()
    # else:
    info_files = glob(os.path.join(args.basedir, "*_info.json"))
    if not info_files:
        print('There are no info files in {}'.format(args.basedir))
        return

    for info_file in info_files:
        logger.info("Reading file {} ...".format(info_file))
        print(get_stats_from_file(info_file, args.verbose, args.json))
        print('\n')

    logger.info("Done")


# Run main() only when executed as a script. main() returns None on every
# path, so sys.exit() reports exit status 0.
if __name__ == "__main__":
    sys.exit(main())
