#!/usr/bin/env python
# -*- coding: utf-8 -*-

import argparse
import json
import shlex
import subprocess

from submititnow.umiacs.handlers import profile_handlers
from rich import print as rich_print

# Command-line interface. Options not declared here are left unparsed by
# ``parse_known_args`` so they can be forwarded to ``srun``.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--config",
    default=[],
    action="append",
    help="JSON config file to load; may be given multiple times.",
)
parser.add_argument(
    "--job-name",
    default="dev",
    help="Value passed to srun as --job-name.",
)
# nargs="*" (not "+") so that the default applies when no command is given;
# a "+" positional always requires at least one value and never uses its default.
parser.add_argument(
    "shell",
    nargs="*",
    default=["bash"],
    help="Command to run under srun (default: bash).",
)


def removeprefix(var: str, prefix: str):
    """Return ``var`` with a leading ``prefix`` stripped once.

    If ``var`` does not start with ``prefix`` it is returned unchanged.
    """
    if not var.startswith(prefix):
        return var
    return var[len(prefix) :]


def load_config(config_filename: str):
    """Load a JSON config file and return it as a dict of option names to values.

    If the loaded config has a ``"profile"`` entry, that entry is removed and
    the matching callable from ``profile_handlers`` is invoked with the
    remaining config; its return value replaces the config.

    Every key then has a leading ``slurm_`` stripped and underscores replaced
    with dashes (``slurm_cpus_per_task`` -> ``cpus-per-task``).
    """
    with open(config_filename) as handle:
        settings = json.load(handle)

    if "profile" in settings:
        profile_name = settings.pop("profile")
        handler = profile_handlers[profile_name]
        settings = handler(settings)

    normalized = {}
    for raw_key, value in settings.items():
        option_name = removeprefix(raw_key, "slurm_").replace("_", "-")
        normalized[option_name] = value
    return normalized


def parse_arguments():
    """Return the SRUN option mapping and the command to run on the execution node.

    Arguments the module-level ``parser`` does not recognise are inspected:
    those of the form ``--key=value`` become SRUN options, anything not
    starting with ``--`` is ignored.

    Precedence (lowest to highest): config files in the order given,
    ``--key=value`` options from the command line, then ``--job-name``.

    Returns:
        A ``(cmd_args, downstream_cmd)`` tuple. ``cmd_args`` maps option names
        (without the leading ``--``) to their values; ``downstream_cmd`` is
        the positional ``shell`` arguments joined by single spaces.

    Raises:
        ValueError: If an unrecognised ``--option`` is given without ``=value``.
    """
    args, downstream_args = parser.parse_known_args()

    # Create args dict from downstream args
    cli_args = {}
    for arg in downstream_args:
        if not arg.startswith("--"):
            continue
        # Split on the first "=" only, so values may themselves contain "="
        # (e.g. --export=VAR=value).
        key, separator, value = arg[2:].partition("=")
        if not separator:
            raise ValueError(
                f"Option {arg!r} must be given in the form --{key}=VALUE"
            )
        cli_args[key] = value

    # Populate arguments from config files; later files override earlier ones
    config_args = {}
    for config_filename in args.config:
        config_args.update(load_config(config_filename))

    # Merge arguments from config files and cli
    cmd_args = {**config_args, **cli_args, "job-name": args.job_name}

    # NOTE: the command words are joined with plain spaces, without quoting.
    return cmd_args, " ".join(args.shell)


def make_shell_command(cmd_args: dict, downstream_cmd: str):
    """Return the full SRUN command line to execute on the shell.

    Args:
        cmd_args: Mapping of SRUN option names (without the leading ``--``)
            to values; each entry is rendered as ``--key=value``.
        downstream_cmd: Command string appended verbatim after ``--pty``.

    Returns:
        A single string of the form ``srun [--key=value ...] --pty <downstream_cmd>``.
    """
    # Quote each option token so values containing whitespace or shell
    # metacharacters reach srun as one argument instead of being split or
    # interpreted by the shell. Tokens made of safe characters are unchanged.
    options = [shlex.quote(f"--{key}={value}") for key, value in cmd_args.items()]
    return " ".join(["srun", *options, "--pty", downstream_cmd])


if __name__ == "__main__":
    cmd_args, downstream_cmd = parse_arguments()
    cmd = make_shell_command(cmd_args, downstream_cmd)

    # Echo the exact command line before running it.
    rich_print(f"[bold]{cmd}")

    # cmd is a single command-line string, so it is run through the shell.
    completed = subprocess.run(cmd, shell=True)

    # Propagate the command's exit status instead of always exiting with 0,
    # so callers and scripts can detect a failed srun.
    raise SystemExit(completed.returncode)
