#!python

import argparse
import os
from tabulate import tabulate

from ec2_cluster.utils import get_dlamis, get_my_amis, get_config_params
from ec2_cluster.infra import ConfigCluster
from ec2_cluster.control import ClusterShell
from ec2_cluster.orch import set_up_passwordless_ssh_from_master_to_workers, generate_hostfile

















def _parse_bool(value):
    """Convert a command-line string such as "true" or "0" to a real bool.

    The builtin `bool` cannot be used as an argparse `type=` converter because
    `bool("False")` is True: every non-empty string is truthy.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got: {value!r}')


# Translate `param_type` string in clusterdef_params to Python type for argparse arguments
def parse_type(type_as_str):
    """Return the argparse `type=` converter for a `param_type` string.

    Args:
        type_as_str: One of "str", "int", "float", "bool" or "list".

    Returns:
        A callable that converts the raw command-line string to the requested type.

    Raises:
        RuntimeError: If `type_as_str` is not one of the recognised type names.
    """
    if type_as_str == "str":
        return str
    elif type_as_str == "int":
        return int
    elif type_as_str == "float":
        return float
    elif type_as_str == "bool":
        # Not the builtin `bool`: that would turn "False" into True.
        return _parse_bool
    elif type_as_str == "list":
        # NOTE(review): as an argparse converter, `list` splits the argument into single
        # characters. The intended list format is not visible here, so this is left as-is.
        return list
    else:
        raise RuntimeError(f'Unrecognized type string: {type_as_str}')

def horovod_setup(cluster, ssh_key, slots, ssh_to_private_ip=False):
    """Set up passwordless SSH from the master to the workers and generate a hostfile.

    Args:
        cluster: Cluster object. This function reads `cluster.ips` (keys "master_public_ip",
            "master_private_ip" and "worker_private_ips") and `cluster.config.username`.
        ssh_key: Either an absolute path to the private key, or a key name that is resolved
            to `~/.ssh/<name>.pem` (the `.pem` suffix is appended when missing).
        slots: Slot count handed to `generate_hostfile`.
        ssh_to_private_ip: When True, connect to the master via its private IP instead of its
            public IP.

    Raises:
        FileNotFoundError: If no key file exists at the resolved path.
    """
    if os.path.isabs(ssh_key):
        ssh_key_path = ssh_key
    else:
        key_file = ssh_key if ssh_key.endswith('.pem') else ssh_key + '.pem'
        ssh_key_path = os.path.join(os.path.expanduser('~/.ssh/'), key_file)

    # Raise explicitly instead of using `assert`: assertions are stripped under `python -O`,
    # which would let a missing key surface later as an obscure SSH failure.
    if not os.path.isfile(ssh_key_path):
        raise FileNotFoundError(
            f'Did not find an SSH key at {ssh_key_path}\n'
            f'If the key is not located at ~/.ssh/$KEY_NAME.pem, use the --ssh_key flag to '
            f'specify the absolute path to the key')

    master_ip = cluster.ips["master_private_ip"] if ssh_to_private_ip else cluster.ips["master_public_ip"]
    worker_ips = cluster.ips["worker_private_ips"]
    shell = ClusterShell(cluster.config.username,
                         master_ip,
                         worker_ips,
                         ssh_key_path=ssh_key_path,
                         use_bastion=False)
    set_up_passwordless_ssh_from_master_to_workers(shell, master_ip, worker_ips=worker_ips, verbose=True)

    generate_hostfile(cluster, shell, slots, use_localhost=False)





def handle_core():
    """Parse the command line and run one of the core cluster actions.

    Supported actions are `create`, `delete`, `describe`, `ssh-cmd`, `setup-horovod` and `test`.
    Every param returned by `get_config_params()` is also exposed as an optional
    `--<param_name>` flag; the flags the user actually set are handed to `ConfigCluster` as
    `other_args`, together with the absolute path of the YAML config file.

    Raises:
        RuntimeError: If `create` targets an existing cluster without `--clean_create`, or if
            `delete`, `ssh-cmd` or `setup-horovod` target a cluster with no running or
            pending node.
        SystemExit: After the `test` action has printed its confirmation.
    """
    cluster_param_list = get_config_params()

    parser = argparse.ArgumentParser()

    parser.add_argument(
            "action",
            choices=["create", "delete", "describe", "ssh-cmd", "setup-horovod", "test"])

    parser.add_argument(
            "config",
            help="Path to config YAML file describing cluster")

    parser.add_argument(
            "--verbose",
            action="store_true")

    parser.add_argument(
            "--ssh_to_private_ip",
            help="Add this flag if you want the ssh-cmd or setup-horovod to use the private IP for the master instead "
                 "of the public IP. Typically this is only needed when you are running this CLI from an EC2 node "
                 "instead of your local machine.",
            action="store_true")

    parser.add_argument(
            "--clean_create",
            help="By default, create will fail if a cluster with this name already exists. With this flag, will instead "
                 "delete the existing cluster and launch a new one.",
            action="store_true")

    parser.add_argument(
            "--horovod",
            help="When creating this cluster, use --horovod to do horovod-setup after nodes are launched.",
            action="store_true")

    parser.add_argument(
            "--ssh_key",
            help="Name of, or absolute path to, your local ssh key. If not an absolute path, assumes the key is "
                 "located at ~/.ssh/$SSH_KEY.pem. Defaults to the key_name from the config. Used by setup-horovod "
                 "and by create with --horovod")

    parser.add_argument(
            "--slots",
            help="For setup-horovod or create --horovod, the number of slots on each host to put in the hostfile. "
                 "Defaults to 8, correct for p3.16xl and p3dn.24xl",
            type=int,  # keep CLI-supplied values the same type as the integer default
            default=8)

    # Add all the ClusterDef params as CLI arguments
    for param in cluster_param_list:
        parser.add_argument(
                f'--{param["param_name"]}',
                help=f'Cluster definition param. {param["param_desc"]}',
                type=parse_type(param["param_type"]))

    args, _ = parser.parse_known_args()

    if args.action == "test":
        # Smoke test of the CLI wiring only: exit before any cluster object is built.
        print("Tested!")
        raise SystemExit(0)

    def vlog(s):
        if args.verbose:
            print(f'[ec3-verbose] {s}')

    def log(s):
        print(f'[ec3] {s}')

    # Extract any ClusterDef params from CLI arguments the user passed in
    args_as_dict = vars(args)
    cli_args = {param["param_name"]: args_as_dict[param["param_name"]]
                for param in cluster_param_list
                if args_as_dict[param["param_name"]] is not None}

    vlog(f'Command line args: {args}')
    vlog(f'ClusterDef args found in CLI args: {cli_args}')

    if args.config is None:
        config_yaml_abspath = None
    elif os.path.isabs(args.config):
        config_yaml_abspath = args.config
    else:
        config_yaml_abspath = os.path.join(os.getcwd(), args.config)

    vlog(f'Pulling yaml from: {config_yaml_abspath}')
    cluster = ConfigCluster(config_yaml_abspath, other_args=cli_args)
    vlog(f'Final config: {cluster.config}')

    # If you don't explicitly pass in the ssh_key_path, will default to the key_name in config.yaml,
    # located at ~/.ssh/$KEY_NAME.pem
    if args.ssh_key is None:
        args.ssh_key = cluster.config.key_name

    vlog(f'Action = {args.action}')

    if args.action == "create":
        if cluster.any_node_is_running_or_pending():
            if not args.clean_create:
                raise RuntimeError("Trying to create a cluster that already exists. Did you want to use the "
                                   "--clean_create flag?")
            vlog("clean_create: terminating existing instances")
            cluster.terminate(verbose=True)  # Verbose here outputs important information instead of no output.
            vlog("cleaned.")

        log("Launching cluster")
        cluster.launch(verbose=True)  # Verbose here outputs important information instead of no output.
        log("Cluster launched")

        if args.horovod:
            log("Starting Horovod setup")
            try:
                horovod_setup(cluster, args.ssh_key, args.slots, ssh_to_private_ip=args.ssh_to_private_ip)
            except Exception as ex:
                # Only the "port 22" connection failure is worth one retry; any other error
                # must propagate rather than be silently discarded.
                if "[Errno None] Unable to connect to port 22" not in str(ex):
                    raise
                # This happens sometimes, probably due to the node not quite being ready. Retry
                log("Unable to connect, probably because the node wasn't quite ready. Retrying")
                horovod_setup(cluster, args.ssh_key, args.slots, ssh_to_private_ip=args.ssh_to_private_ip)
            log("Finished Horovod setup")

    elif args.action == "delete":
        if cluster.any_node_is_running_or_pending():
            log("Starting delete")
            cluster.terminate(verbose=True)  # Verbose here outputs important information instead of no output.
            log("Deletion complete")
        else:
            raise RuntimeError("Cannot terminate. The cluster does not exist.")

    elif args.action == "describe":
        vlog("describe")
        if not cluster.any_node_is_running_or_pending():
            print("Cluster does not exist")
        else:
            print(cluster.ips)

    elif args.action == "ssh-cmd":
        if not cluster.any_node_is_running_or_pending():
            raise RuntimeError("Cannot print SSH command. The cluster does not exist")
        ssh_ip = cluster.ips["master_private_ip"] if args.ssh_to_private_ip else cluster.ips["master_public_ip"]
        print(f'ssh -A {cluster.config.username}@{ssh_ip}')

    elif args.action == "setup-horovod":
        if not cluster.any_node_is_running_or_pending():
            raise RuntimeError("Cannot run horovod setup. The cluster does not exist")
        log("Starting Horovod setup")
        horovod_setup(cluster, args.ssh_key, args.slots, ssh_to_private_ip=args.ssh_to_private_ip)
        log("Finished Horovod setup")











def text_wrap(txt, width=75):
    """Greedily wrap `txt` into lines, breaking only at single spaces.

    Words are never split. A line is closed as soon as appending the next word (plus its
    trailing space) would bring the line length to `width` or more. Every word is followed
    by one space, so each output line ends with a trailing space.

    Args:
        txt: Text to wrap; it is split on the single-space character.
        width: Length that a line, including trailing spaces, must stay below.

    Returns:
        The wrapped lines joined with newline characters.
    """
    lines = []
    line_words = []
    line_width = 0
    for word in txt.split(" "):
        word_width = len(word) + 1  # the word plus the space that follows it
        # Only close a line that already holds a word; otherwise an over-long first word
        # would emit an empty leading line.
        if line_words and line_width + word_width >= width:
            lines.append(" ".join(line_words) + " ")
            line_words = []
            line_width = 0
        line_words.append(word)
        line_width += word_width
    if line_words:
        lines.append(" ".join(line_words) + " ")

    return "\n".join(lines)


def display_config_params(config_param_list):
    """Print a blank line followed by a Name / Type / Description table of config params.

    Args:
        config_param_list: Iterable of dicts with the keys 'param_name', 'param_type' and
            'param_desc'. Each description is passed through `text_wrap` before display.
    """
    def as_row(entry):
        wrapped_desc = text_wrap(entry['param_desc'])
        return [entry['param_name'], entry['param_type'], wrapped_desc]

    table = [as_row(entry) for entry in config_param_list]

    print("")
    print(tabulate(table, headers=["Name", "Type", "Description"]))


def display_ami_list(amis):
    """Print a blank line followed by a table describing the given AMIs.

    Args:
        amis: Iterable of dicts with the keys 'Name', 'ImageId', 'SnapshotId' and
            'Description'. Each description is passed through `text_wrap` before display.
    """
    column_titles = ["Name", "AMI Id", "EBS Snapshot Id", "Description"]

    table = []
    for image in amis:
        wrapped_desc = text_wrap(image['Description'])
        table.append([image['Name'], image['ImageId'], image['SnapshotId'], wrapped_desc])

    print("")
    print(tabulate(table, headers=column_titles))



def handle_utils():
    """Parse the command line for the `utils` command and run the requested utility.

    Utilities:
        list-dlami: Display the AMIs returned by `get_dlamis` for `--dlami_type` / `--region`.
        list-ami: Display the AMIs returned by `get_my_amis` for `--region`.
        describe-config: Display the params returned by `get_config_params`.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("utils", choices=["utils"])
    cli.add_argument("action", choices=["list-dlami", "list-ami", "describe-config"])
    cli.add_argument(
        "--dlami_type",
        choices=["Ubuntu", "AL"],
        default="Ubuntu",
        help="[list-dlami] Which of the DLAMI flavors to retrieve info about - Ubuntu or Amazon Linux",
    )
    cli.add_argument(
        "--region",
        help="[list-dlami or list-ami] Which AWS region to retrieve information about, e.g. us-east-1. Defaults to AWS CLI default region",
    )

    parsed, _unrecognized = cli.parse_known_args()
    requested = parsed.action

    if requested == "describe-config":
        display_config_params(get_config_params())
        return

    if requested == "list-ami":
        display_ami_list(get_my_amis(region=parsed.region))
        return

    if requested == "list-dlami":
        # The CLI accepts the short form "AL"; get_dlamis is given the spelled-out name.
        flavor = "Amazon Linux" if parsed.dlami_type == 'AL' else parsed.dlami_type
        display_ami_list(get_dlamis(region=parsed.region, ami_type=flavor))





if __name__ == '__main__':
    # First-pass parser: it only identifies the top-level action. Built-in help is disabled
    # so that -h/--help is left for the action-specific parsers in handle_core/handle_utils.
    dispatch_parser = argparse.ArgumentParser(add_help=False)
    dispatch_parser.add_argument(
        "action",
        choices=["help", "create", "delete", "describe", "ssh-cmd", "test", "setup-horovod", "utils"])

    top_level_args, _remaining = dispatch_parser.parse_known_args()
    chosen_action = top_level_args.action

    if chosen_action == 'help':
        print(
            "\n"
            "Command line utility for working with clusters of EC2 instances, primarily for deep learning.\n"
            "Available commands are: \n"
            "     create: Create a cluster\n"
            "     delete: Delete a cluster\n"
            "     describe: List the public and private IPs of the nodes in the cluster\n"
            "     ssh-cmd: Print command to ssh to master node\n"
            "     setup-horovod: Setup ssh in the cluster to enable Horovod. Can be done at create time\n"
            "                    with --horovod flag\n"
            "     utils: Various utilities such as listing the details of the DLAMIs or listing details\n"
            "            of your AMIs\n"
            "\n"
            "Check help for individual commands (ec3 cmd -h) to see more details."
        )
        quit()
    elif chosen_action in ('create', 'delete', 'describe', 'ssh-cmd', 'setup-horovod', 'test'):
        # The handler re-parses sys.argv itself with the full set of options.
        handle_core()
    elif chosen_action == 'utils':
        handle_utils()
    else:
        print(f'Unrecognized action "{chosen_action}". This error should have been caught by argparse and this '
              f'line should not have been printed.')


