#!/usr/bin/env python3






import argparse
import os
import yaml
from tabulate import tabulate

from ec2_cluster.utils import get_dlamis, get_my_amis, get_config_params
from ec2_cluster.infra import ConfigCluster
from ec2_cluster.control import ClusterShell
from ec2_cluster.orch import set_up_passwordless_ssh_from_master_to_workers

















def _str_to_bool(value):
    """Convert a command-line string such as ``"true"`` or ``"0"`` to a bool.

    The builtin ``bool`` cannot be used as an argparse ``type`` because every
    non-empty string (including ``"False"``) is truthy.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got: {value!r}')


# Translate `param_type` string in clusterdef_params to Python type for argparse arguments
def parse_type(type_as_str):
    """Return the converter callable matching a ``param_type`` string.

    Args:
        type_as_str: One of ``"str"``, ``"int"``, ``"float"``, ``"bool"`` or ``"list"``.

    Returns:
        A callable suitable for the ``type=`` argument of ``ArgumentParser.add_argument``.
        ``"bool"`` yields a string-aware parser rather than the builtin ``bool`` so that
        ``--flag False`` really produces ``False``.

    Raises:
        RuntimeError: If ``type_as_str`` is not one of the recognised names.
    """
    converters = {
        "str": str,
        "int": int,
        "float": float,
        "bool": _str_to_bool,
        # NOTE(review): as an argparse type, `list` splits the raw string into single
        # characters. Kept unchanged because the intended list syntax is not visible here.
        "list": list,
    }
    try:
        return converters[type_as_str]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which previously also fell through to this error.
        raise RuntimeError(f'Unrecognized type string: {type_as_str}') from None

def horovod_setup(cluster, ssh_key_path, ssh_to_private_ip=False):
    """Set up passwordless SSH from the master node to the workers.

    Builds a ``ClusterShell`` for the cluster and hands it to
    ``set_up_passwordless_ssh_from_master_to_workers``.

    Args:
        cluster: Object exposing ``ips`` (a mapping with the keys ``master_public_ip``,
            ``master_private_ip`` and ``worker_private_ips``) and ``config.username``.
        ssh_key_path: Path to the SSH key, forwarded to ``ClusterShell``.
        ssh_to_private_ip: When true, reach the master through its private IP
            instead of its public IP.
    """
    # Pick the master address first, then the workers, mirroring the lookup order on `cluster.ips`.
    if ssh_to_private_ip:
        master_address = cluster.ips["master_private_ip"]
    else:
        master_address = cluster.ips["master_public_ip"]
    worker_addresses = cluster.ips["worker_private_ips"]

    cluster_shell = ClusterShell(
        cluster.config.username,
        master_address,
        worker_addresses,
        ssh_key_path=ssh_key_path,
        use_bastion=False,
    )
    set_up_passwordless_ssh_from_master_to_workers(
        cluster_shell,
        master_address,
        worker_ips=worker_addresses,
    )


def handle_core():
    """Parse the command line and run one core cluster action.

    Supported actions are ``create``, ``delete``, ``describe``, ``ssh-cmd``,
    ``setup-horovod`` and ``test``. Every entry returned by ``get_config_params()``
    is also exposed as an optional ``--<param_name>`` flag; the flags the user
    actually supplied are passed to ``ConfigCluster`` as ``other_args`` together
    with the absolute path of the YAML config file.

    Raises:
        RuntimeError: If ``--ssh_key_path`` is missing for a Horovod setup, if
            ``create`` targets an existing cluster without ``--clean_create``, or
            if ``delete``, ``ssh-cmd`` or ``setup-horovod`` target a cluster with
            no running or pending node.
        SystemExit: After the ``test`` action, or from argparse on invalid arguments.
    """
    cluster_param_list = get_config_params()

    parser = argparse.ArgumentParser()

    parser.add_argument(
            "action",
            choices=["create", "delete", "describe", "ssh-cmd", "setup-horovod", "test"])

    parser.add_argument(
            "config",
            help="Path to config YAML file describing cluster")

    parser.add_argument(
            "--verbose",
            action="store_true")

    parser.add_argument(
            "--ssh_to_private_ip",
            help="Add this flag if you want the ssh-cmd or setup-horovod to use the private IP for the master instead "
                 "of the public IP. Typically this is only needed when you are running this CLI from an EC2 node "
                 "instead of your local machine.",
            action="store_true")

    parser.add_argument(
            "--clean_create",
            help="By default, create will fail if a cluster with this name already exists. With this flag, will "
                 "instead delete the existing cluster and launch a new one.",
            action="store_true")

    parser.add_argument(
            "--horovod",
            help="When creating this cluster, use --horovod to do setup-horovod after nodes are launched.",
            action="store_true")

    parser.add_argument(
            "--ssh_key_path",
            help="Absolute path to your local ssh_key. Required for setup-horovod or when using --horovod with create")

    # Add all the ClusterDef params as CLI arguments
    for param in cluster_param_list:
        parser.add_argument(
                f'--{param["param_name"]}',
                help=f'Cluster definition param. {param["param_desc"]}',
                type=parse_type(param["param_type"]))

    # parse_known_args tolerates extra arguments; the leftovers are intentionally ignored.
    args, _ = parser.parse_known_args()

    if args.action == "test":
        print("Tested!")
        # SystemExit instead of quit(): quit() is injected by the `site` module and is
        # not guaranteed to exist (e.g. `python -S`).
        raise SystemExit(0)

    def vlog(s):
        if args.verbose:
            print(f'[cli_v2.py] {s}')

    # Validate the SSH key requirement up front so that `create --horovod` does not
    # launch instances and only then fail. Raised explicitly because `assert` is
    # stripped when Python runs with -O.
    if args.ssh_key_path is None:
        if args.action == "create" and args.horovod:
            raise RuntimeError("If using --horovod, you must provide --ssh_key_path")
        if args.action == "setup-horovod":
            raise RuntimeError("If using setup-horovod, you must provide --ssh_key_path")

    # Extract any ClusterDef params from CLI arguments the user passed in
    args_as_dict = vars(args)
    cli_args = {param["param_name"]: args_as_dict[param["param_name"]]
                for param in cluster_param_list
                if args_as_dict[param["param_name"]] is not None}

    vlog(f'Command line args: {args}')
    vlog(f'ClusterDef args found in CLI args: {cli_args}')

    if args.config is None:
        config_yaml_abspath = None
    elif os.path.isabs(args.config):
        config_yaml_abspath = args.config
    else:
        # Relative paths are resolved against the directory the CLI was started from.
        config_yaml_abspath = os.path.join(os.getcwd(), args.config)

    vlog(f'Pulling yaml from: {config_yaml_abspath}')
    cluster = ConfigCluster(config_yaml_abspath, other_args=cli_args)
    vlog(f'Final config: {cluster.config}')

    vlog(f'Action = {args.action}')

    if args.action == "create":
        if cluster.any_node_is_running_or_pending():
            if not args.clean_create:
                raise RuntimeError("Trying to create a cluster that already exists. Did you want to use the "
                                   "--clean_create flag?")
            vlog("clean_create: terminating existing instances")
            cluster.terminate(verbose=args.verbose)
            vlog("cleaned.")

        vlog("Launching cluster")
        cluster.launch(verbose=args.verbose)
        vlog("Cluster launched")

        if args.horovod:
            vlog("Starting Horovod setup")
            horovod_setup(cluster, args.ssh_key_path, ssh_to_private_ip=args.ssh_to_private_ip)
            vlog("Finished Horovod setup")

    elif args.action == "delete":
        if not cluster.any_node_is_running_or_pending():
            raise RuntimeError("Cannot terminate. The cluster does not exist.")
        vlog("Starting delete")
        cluster.terminate(verbose=args.verbose)
        vlog("Deletion complete")

    elif args.action == "describe":
        vlog("describe output:")
        if not cluster.any_node_is_running_or_pending():
            print("Cluster does not exist")
        else:
            print(cluster.ips)

    # The action names below must match the argparse `choices` exactly; the previous
    # spellings ("ssh_cmd", "horovod-setup") could never match, so these branches never ran.
    elif args.action == "ssh-cmd":
        if not cluster.any_node_is_running_or_pending():
            raise RuntimeError("Cannot print SSH command. The cluster does not exist")
        ssh_ip = cluster.ips["master_private_ip"] if args.ssh_to_private_ip else cluster.ips["master_public_ip"]
        # ssh expects user@host; "user:host" would be treated as a single hostname.
        print(f'ssh -A {cluster.config.username}@{ssh_ip}')

    elif args.action == "setup-horovod":
        if not cluster.any_node_is_running_or_pending():
            raise RuntimeError("Cannot run horovod setup. The cluster does not exist")
        vlog("Starting Horovod setup")
        horovod_setup(cluster, args.ssh_key_path, ssh_to_private_ip=args.ssh_to_private_ip)
        vlog("Finished Horovod setup")











def text_wrap(txt, width=75):
    """Greedily wrap ``txt`` on single spaces into newline-separated lines.

    Words are added to the current line until the line (counting one separator
    per word) would reach ``width``; the word that would overflow starts a new
    line. Words are never split, so a single word longer than ``width`` occupies
    a line of its own. Existing newlines in ``txt`` are not treated specially.

    Args:
        txt: Text to wrap.
        width: Wrap threshold in characters.

    Returns:
        The wrapped text, with no trailing space on any line and no empty
        leading line when the first word is itself too long.
    """
    output_lines = []
    current_words = []
    current_width = 0
    for word in txt.split(" "):
        cost = len(word) + 1  # the word plus the space that follows it
        # Only break when the current line already holds something; otherwise an
        # over-long first word would emit an empty line before it.
        if current_words and current_width + cost >= width:
            output_lines.append(" ".join(current_words))
            current_words = []
            current_width = 0
        current_words.append(word)
        current_width += cost
    if current_words:
        output_lines.append(" ".join(current_words))

    return "\n".join(output_lines)


def display_config_params(config_param_list):
    """Print a table of config parameters with Name, Type and Description columns.

    Args:
        config_param_list: Iterable of mappings, each providing ``param_name``,
            ``param_type`` and ``param_desc``. The description is passed through
            ``text_wrap`` before being placed in the table.
    """
    def as_row(entry):
        # Wrap the description first so long text spans several lines within its cell.
        wrapped_desc = text_wrap(entry['param_desc'])
        return [entry['param_name'], entry['param_type'], wrapped_desc]

    table_rows = [as_row(entry) for entry in config_param_list]

    print("")
    print(tabulate(table_rows, headers=["Name", "Type", "Description"]))


def display_ami_list(amis):
    """Print a table of AMIs: Name, AMI Id, EBS Snapshot Id and Description.

    Args:
        amis: Iterable of mappings, each providing ``Name``, ``ImageId``,
            ``SnapshotId`` and ``Description``. The description is passed through
            ``text_wrap`` before being placed in the table.
    """
    def as_row(image):
        # Wrap the description first so long text spans several lines within its cell.
        wrapped_desc = text_wrap(image['Description'])
        return [image['Name'], image['ImageId'], image['SnapshotId'], wrapped_desc]

    table_rows = [as_row(image) for image in amis]

    print("")
    print(tabulate(table_rows, headers=["Name", "AMI Id", "EBS Snapshot Id", "Description"]))



def handle_utils():
    """Parse the ``utils`` sub-command line and run the requested utility.

    Expects the literal word ``utils`` followed by one of ``list-dlami``,
    ``list-ami`` or ``describe-config``. AMI listings are fetched with
    ``get_dlamis`` / ``get_my_amis`` and printed with ``display_ami_list``;
    ``describe-config`` prints the result of ``get_config_params`` with
    ``display_config_params``. Unrecognised extra arguments are ignored.
    """
    utils_parser = argparse.ArgumentParser()

    # First positional is the sub-command word itself, second is the utility to run.
    utils_parser.add_argument("utils", choices=["utils"])
    utils_parser.add_argument("action", choices=["list-dlami", "list-ami", "describe-config"])

    utils_parser.add_argument(
        "--dlami_type",
        default="Ubuntu",
        choices=["Ubuntu", "AL"],
        help="[list-dlami] Which of the DLAMI flavors to retrieve info about - Ubuntu or Amazon Linux",
    )
    utils_parser.add_argument(
        "--region",
        help="[list-dlami or list-ami] Which AWS region to retrieve information about, e.g. us-east-1. Defaults to AWS CLI default region",
    )

    options, _unused = utils_parser.parse_known_args()
    requested = options.action

    if requested == "describe-config":
        display_config_params(get_config_params())
        return

    if requested == "list-ami":
        display_ami_list(get_my_amis(region=options.region))
        return

    if requested == "list-dlami":
        display_ami_list(get_dlamis(region=options.region, ami_type=options.dlami_type))





if __name__ == '__main__':
    # Script entry point: look only at the first positional argument to decide which
    # handler owns the rest of the command line. Help is disabled on this parser so
    # that -h/--help falls through to the handler's own parser.
    dispatch_parser = argparse.ArgumentParser(add_help=False)
    dispatch_parser.add_argument(
        "action",
        choices=["help", "create", "delete", "describe", "ssh-cmd", "test", "setup-horovod", "utils"],
    )
    dispatch_args, _remaining = dispatch_parser.parse_known_args()
    requested_action = dispatch_args.action

    if requested_action == 'help':
        print(
            "\n"
            "Command line utility for working with clusters of EC2 instances, primarily for deep learning.\n"
            "Available commands are: \n"
            "     create: Create a cluster\n"
            "     delete: Delete a cluster\n"
            "     describe: List the public and private IPs of the nodes in the cluster\n"
            "     ssh-cmd: Print command to ssh to master node\n"
            "     setup-horovod: Setup ssh in the cluster to enable Horovod. Can be done at create time\n"
            "                    with --horovod flag\n"
            "     utils: Various utilities such as listing the details of the DLAMIs or listing details\n"
            "            of your AMIs\n"
        )
        quit()
    elif requested_action == 'utils':
        handle_utils()
    elif requested_action in ('create', 'delete', 'describe', 'ssh-cmd', 'setup-horovod', 'test'):
        handle_core()
    else:
        # Defensive fallback: argparse `choices` should already have rejected anything else.
        print(f'Unrecognized action "{requested_action}". This error should have been caught by argparse and this '
              f'line should not have been printed.')


