#!/usr/bin/env python

import os
import sys
import argparse
import importlib
import itertools
from typing import Any, Mapping, Iterable, Sequence

from submititnow import options
from submititnow import experiment_lib
from submititnow.umiacs import handlers

# Put the current working directory first on the import path. The __main__
# section below imports the target script by a module name derived from the
# `src_file` path, so that path is presumably given relative to the directory
# the launcher is run from.
sys.path.insert(0, os.getcwd())


def get_module_name(src_file: str):
    """Convert the path of a Python source file into a dotted module name.

    The path is normalized first, so a leading ``./`` or doubled separators
    do not leak into the module name. Only the trailing ``.py`` suffix is
    removed, so a dot elsewhere in the path does not truncate the result.

    Args:
        src_file: Path to a ``.py`` file, e.g. ``"pkg/train.py"``.

    Returns:
        The dotted module name, e.g. ``"pkg.train"``.

    Raises:
        ValueError: If ``src_file`` does not end with ``.py``.
    """
    suffix = ".py"
    # Explicit raise instead of `assert`: asserts are stripped under `-O`.
    if not src_file.endswith(suffix):
        raise ValueError(f"'{src_file}' is not a Python source file.")

    module_path = os.path.normpath(src_file)[: -len(suffix)]
    # normpath emits the platform separator; fold it back to "/" so both
    # separator styles are split the same way.
    return ".".join(module_path.replace(os.sep, "/").split("/"))


def kw_product(params: Mapping[str, Any]):
    """Yield one dict per combination in the Cartesian product of ``params``.

    Every value in ``params`` is an iterable of candidates for its key; each
    yielded dict maps every key to exactly one of its candidates. An empty
    mapping yields a single empty dict.
    """
    names = list(params)
    candidates = [params[name] for name in names]
    # product() with no iterables yields one empty tuple, which covers the
    # empty-mapping case without a separate branch.
    for combination in itertools.product(*candidates):
        yield {name: choice for name, choice in zip(names, combination)}


def make_args_sweepable(parser: argparse.ArgumentParser, sweep_args: Iterable[str]):
    """Let the swept arguments of ``parser`` accept one or more values.

    Every action whose ``dest`` is named in ``sweep_args`` is modified in
    place: its ``nargs`` becomes ``"+"`` and it is marked as required.

    Args:
        parser: Argument parser to modify in place.
        sweep_args: Destination names of the arguments to sweep over. May be
            any iterable, including a one-shot iterator.

    Returns:
        The same ``parser`` object.

    Raises:
        SystemExit: Via ``parser.error`` when a swept argument was declared
            with ``nargs="*"``.
    """
    # Materialize once: the membership test below runs once per action, and
    # a one-shot iterator would be exhausted by the very first test.
    sweep_names = set(sweep_args)

    # argparse offers no public accessor for the registered actions.
    for action in parser._actions:
        if action.dest not in sweep_names:
            continue
        if action.nargs == "*":
            parser.error(
                f"Cannot sweep over variable {action.dest} because it is a list."
            )
        action.nargs = "+"
        action.required = True
    return parser


def __merge_dicts(dict1, dict2):
    """Return a new dict holding ``dict1`` overlaid with ``dict2``.

    Neither input is modified; on a key collision the value from ``dict2``
    wins.
    """
    merged = {}
    for source in (dict1, dict2):
        merged.update(source)
    return merged


def create_module_args_list(
    module_argparser: argparse.ArgumentParser,
    sweep_args: Iterable[str],
    downstream_args: Sequence[str],
):
    """Expand swept command-line arguments into one namespace per job.

    The swept arguments of ``module_argparser`` are first made multi-valued,
    then ``downstream_args`` is parsed and the Cartesian product of the swept
    values is taken. Each resulting namespace carries every parsed argument,
    with each swept argument replaced by a single value from its list.

    Args:
        module_argparser: Argument parser of the target module; it is
            modified in place by ``make_args_sweepable``.
        sweep_args: Destination names of the arguments to sweep over. May be
            any iterable, including a one-shot iterator.
        downstream_args: Command-line tokens to parse with the module parser.

    Returns:
        A list of ``argparse.Namespace`` objects, one per combination of
        swept values (a single namespace when nothing is swept).
    """
    # Materialize once: the names are needed both by make_args_sweepable and
    # by the filter below, and a one-shot iterator would be empty on reuse.
    sweep_names = set(sweep_args)
    make_args_sweepable(module_argparser, sweep_names)

    parsed_args = vars(module_argparser.parse_args(downstream_args))
    swept_values = {
        name: values for name, values in parsed_args.items() if name in sweep_names
    }

    return [
        argparse.Namespace(**__merge_dicts(parsed_args, combination))
        for combination in kw_product(swept_values)
    ]


def job_description_function(args: argparse.Namespace):
    """Describe a job by the values of its swept parameters.

    Returns a comma-separated ``name=value`` listing of the entries of
    ``args`` whose names are in the module-level ``sweep_args`` (string
    values are wrapped in single quotes), or ``"---"`` when there are none.
    """
    described = [
        f"{name}='{value}'" if isinstance(value, str) else f"{name}={value}"
        for name, value in vars(args).items()
        if name in sweep_args
    ]
    if not described:
        return "---"
    return ", ".join(described)


class UnSupportedPythonModuleError(Exception):
    """Raised when a target module lacks a function the launcher requires.

    Attributes:
        module_name: Name of the module that was rejected.
        missing_attr: Name of the attribute that could not be found on it.
    """

    def __init__(self, module_name: str, missing_attr: str):
        # Kept on the instance so handlers can inspect the failure without
        # parsing the message.
        self.module_name = module_name
        self.missing_attr = missing_attr
        super().__init__(
            f"Module '{module_name}' is not supported by submititnow. '{missing_attr}'"
            " is missing.\nTarget script must have two functions: \n\t* 'main(args:"
            " argparse.Namespace)'\n\t* 'add_arguments(parser = None) ->"
            " argparse.ArgumentParser'."
        )


if __name__ == "__main__":
    # Options understood by the launcher itself. Anything it does not
    # recognize is handed to the target module's own parser further down.
    parser = argparse.ArgumentParser()
    parser.add_argument("src_file", help="Source file to execute")

    options.add_submititnow_arguments(parser)
    options.add_slurm_arguments(parser)

    parser.add_argument(
        "--sweep", nargs="*", default=[], help="List of parameters to sweep over."
    )

    parser.add_argument(
        "--silent",
        action="store_true",
        help="Boolean flag to run experiments in silent mode.",
    )

    parser.add_argument(
        "--wait_until",
        default="submitted",
        choices=["none", "submitted", "running", "done"],
        help="Wait until the job is in the specified state before returning.",
    )

    args, downstream_args = parser.parse_known_args()

    target_module_name = get_module_name(args.src_file)
    # Fall back to the module name when no experiment name was given.
    exp_name = args.exp_name or target_module_name
    target_module = importlib.import_module(target_module_name)

    # Guard only the attribute lookup: an AttributeError raised inside the
    # target's add_arguments() must surface as-is rather than be misreported
    # as the function being missing.
    try:
        add_arguments_func = target_module.add_arguments
    except AttributeError as e:
        raise UnSupportedPythonModuleError(target_module_name, "add_arguments") from e
    module_argparser = add_arguments_func()

    try:
        module_main_func = target_module.main
    except AttributeError as e:
        raise UnSupportedPythonModuleError(target_module_name, "main") from e

    # Must remain a module-level name: job_description_function reads it.
    sweep_args = {*args.sweep}
    module_args_list = create_module_args_list(
        module_argparser, sweep_args, downstream_args
    )

    experiment = experiment_lib.Experiment(
        exp_name,
        job_func=module_main_func,
        job_params=module_args_list,
        job_desc_function=job_description_function,
        submititnow_dir=args.submititnow_dir,
    )
    for name, handler in handlers.profile_handlers.items():
        experiment.register_profile_handler(name, handler)

    slurm_params = options.get_slurm_params(args)

    experiment.launch(slurm_params, verbose=not args.silent, wait_until=args.wait_until)
