#!/usr/bin/env python3

import os

import sys  # noqa


# convenience: load implementations from cwd
sys.path.append(os.getcwd())

# The multiprocessing backend is chosen through the environment: only the
# exact string "True" in USE_PYTORCH_MP selects torch.multiprocessing, every
# other value (or an unset variable) falls back to the standard library.
pytorch_mp = os.environ.get("USE_PYTORCH_MP", "False") == "True"
if pytorch_mp:
    print("Using pytorch mp")
    import torch.multiprocessing as mp
else:
    import multiprocessing as mp


import argparse  # noqa
import yaml  # noqa

from edflow.main import train, test  # noqa
from edflow.custom_logging import init_project, use_project, get_logger  # noqa
from edflow.custom_logging import set_global_stdout_level  # noqa
from edflow.hooks.checkpoint_hooks.common import get_latest_checkpoint  # noqa
from edflow.config import parse_unknown_args, update_config


def _load_config(base_paths, config_path, additional_kwargs, verbose=False):
    """Merge base configs, an optional main config and command line overrides.

    Parameters
    ----------
    base_paths : list of str or None
        Paths to base yaml files. They are merged in order, so later files
        overwrite keys of earlier ones.
    config_path : str or None
        Path to the train or eval yaml file. Its entries overwrite those of
        the base configs. Skipped if ``None``.
    additional_kwargs : dict
        Extra options from the command line, applied last via
        ``update_config``.
    verbose : bool
        If ``True``, print each base path and the merged base config.

    Returns
    -------
    dict
        The merged configuration.
    """
    config = {}
    for base in base_paths or ():
        if verbose:
            print(base)
        with open(base) as f:
            # NOTE(review): full_load is applied to user supplied files;
            # confirm that configs always come from a trusted source.
            # An empty yaml file loads as None, hence the fallback to {}.
            config.update(yaml.full_load(f) or {})
    if verbose:
        print(config)
    if config_path is not None:
        with open(config_path) as f:
            config.update(yaml.full_load(f) or {})
    update_config(config, additional_kwargs)
    return config


def _setup_project(opt, additional_kwargs):
    """Reuse the project given by ``opt.project`` or initialize a new one.

    For a new project the code root is taken from the ``code_root`` config
    entry or, if missing, from the first component of the dotted ``model``
    path. The log directory postfix is ``opt.name`` if given, otherwise the
    ``experiment_name`` config entry.

    Raises
    ------
    KeyError
        If a new project is created and the merged config has no ``model``
        entry.
    """
    if opt.project is not None:
        return use_project(opt.project, postfix=opt.name)

    # Get base or default parameters, then the path to the implementation
    config = _load_config(
        opt.base, opt.train or None, additional_kwargs, verbose=True
    )
    impl = config["model"]
    name = config.get("experiment_name", None)

    if "code_root" in config:
        code_root = config["code_root"]
    else:
        # NOTE(review): str.split always returns at least one element, so
        # the former fallback to "." (cwd) for non-package paths could never
        # be reached. Behavior is kept as is - confirm whether a model path
        # without dots was meant to map to cwd.
        code_root = impl.split(".")[0]

    if opt.name is not None:
        # command line takes precedence over "experiment_name" from config
        name = opt.name
    return init_project("logs", code_root=code_root, postfix=name)


def _run_processes(processes, JQ, logger):
    """Start all processes and wait until each one reported on the queue.

    Every item on ``JQ`` is a triple ``(process index, payload, traceback)``
    where the payload is either an exception instance or the string
    ``"Done"``. An exception reported by a child is logged and reraised
    here, anything else that is not ``"Done"`` raises a generic exception.
    On any error or a keyboard interrupt all started processes are
    terminated; a keyboard interrupt is not reraised.
    """
    # Only processes that were actually started may be terminated or joined.
    started = list()
    try:
        for p in processes:
            p.start()
            started.append(p)
        logger.info("Started {} process(es).".format(len(processes)))

        done_count = 0
        # Checking the count before reading avoids blocking forever on the
        # queue when there is nothing to run.
        while done_count < len(processes):
            pidx, exc_or_done, trace = JQ.get()
            if isinstance(exc_or_done, Exception):
                logger.warning("Exception in process {}:".format(pidx))
                logger.warning(trace)
                raise exc_or_done
            elif exc_or_done == "Done":
                logger.info("Process {} is done".format(pidx))
                done_count += 1
            else:
                raise Exception("Unknown element on queue.")

    except (Exception, KeyboardInterrupt) as e:
        logger.info("Terminating all processes")
        for p in started:
            p.terminate()
        if not isinstance(e, KeyboardInterrupt):
            # reraise if we caught a general exception
            # TODO: only reraise if this is not an exception of a child that
            # was already logged above
            raise
    finally:
        # Landing
        for p in started:
            p.join()
        logger.info("Finished")


def main(opt, additional_kwargs):
    """Set up the project and run training and/or evaluations in subprocesses.

    Parameters
    ----------
    opt : argparse.Namespace
        Parsed command line options. The attributes ``project``, ``name``,
        ``base``, ``train``, ``eval``, ``checkpoint``, ``retrain``, ``nogpu``
        and ``log_level`` are read. ``opt.eval`` is replaced by an empty
        list if it is ``None``.
    additional_kwargs : dict
        Extra ``--key value`` options which are merged into every config.

    Raises
    ------
    Exception
        Any exception reported by a child process, or raised while starting
        and supervising the processes, is reraised after all started
        processes were terminated.
    """
    # Project manager
    P = _setup_project(opt, additional_kwargs)

    # Logger
    set_global_stdout_level(opt.log_level)
    logger = get_logger("main")
    logger.info(opt)
    logger.info(P)

    # Processes
    processes = list()

    # Error Communication between processes
    JQ = mp.Queue()

    # Training
    if opt.train:
        if opt.checkpoint is not None:
            checkpoint = opt.checkpoint
        elif opt.project is not None:
            checkpoint = get_latest_checkpoint(P.checkpoints)
        else:
            checkpoint = None

        config = _load_config(
            opt.base, opt.train, additional_kwargs, verbose=True
        )

        logger.info("Training config: {}\n{}".format(opt.train, yaml.dump(config)))

        train_process = mp.Process(
            target=train,
            args=((config, P.train, checkpoint, opt.retrain), JQ, len(processes)),
        )
        processes.append(train_process)

    # Evaluation
    opt.eval = opt.eval or list()
    for eval_idx, eval_config in enumerate(opt.eval):
        config = _load_config(opt.base, eval_config, additional_kwargs)

        # With a fixed checkpoint, eval_all and eval_forever are switched
        # off. Both are always disabled together.
        disable_continuous_eval = False
        if opt.checkpoint is not None:
            checkpoint = opt.checkpoint
            disable_continuous_eval = True
        elif opt.project is not None:
            if any([config.get("eval_all", False), config.get("eval_forever", False)]):
                checkpoint = None
            else:
                checkpoint = get_latest_checkpoint(P.checkpoints)
                disable_continuous_eval = True
        else:
            checkpoint = None

        if disable_continuous_eval:
            for key in ("eval_all", "eval_forever"):
                config.update({key: False})
                logger.info(
                    "{} was disabled because you specified a checkpoint.".format(key)
                )

        # Dump the merged config, not the path of the config file.
        logger.info(
            "Evaluation config: {}\n{}".format(eval_config, yaml.dump(config))
        )
        # Evaluations get no gpu as soon as another process was registered.
        nogpu = len(processes) > 0 or opt.nogpu
        bar_position = len(processes) + eval_idx

        test_process = mp.Process(
            target=test,
            args=(
                (config, P.latest_eval, checkpoint, nogpu, bar_position),
                JQ,
                len(processes),
            ),
        )
        processes.append(test_process)

    # Take off
    _run_processes(processes, JQ, logger)


if __name__ == "__main__":
    # Command line interface: known options are parsed by argparse, all
    # remaining ``--key value`` pairs are forwarded as config overrides.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-n", "--name", metavar="description", help="postfix of log directory."
    )
    parser.add_argument(
        "-b",
        "--base",
        nargs="*",
        metavar="base_config.yaml",
        help="Path to base config. Any parameter in here is overwritten by "
        "the train or eval config. Useful e.g. for model parameters, which"
        " stay constant between trainings and evaluations.",
        default=None,
    )
    parser.add_argument(
        "-t", "--train", metavar="config.yaml", help="path to training config"
    )
    parser.add_argument(
        "-e",
        "--eval",
        nargs="*",
        metavar="config.yaml",
        help="path to evaluation configs",
    )
    parser.add_argument("-p", "--project", help="path to existing project")
    parser.add_argument("-c", "--checkpoint", help="path to existing checkpoint")
    parser.add_argument(
        "-r", "--retrain", action="store_true", help="reset global step"
    )
    parser.add_argument(
        "--nogpu", action="store_true", help="disable gpu for tensorflow"
    )
    parser.add_argument(
        "-log",
        "--log-level",
        metavar="LEVEL",
        type=str,
        choices=["warn", "info", "debug", "critical"],
        default="info",
        help="Set the std-out logging level.",
    )

    # Options unknown to the parser are not an error; they are turned into
    # additional config entries.
    opt, unknown = parser.parse_known_args()
    additional_kwargs = parse_unknown_args(unknown)

    # --option/-o is no longer a dedicated flag. It still ends up in
    # additional_kwargs like any other key, so only print a notice.
    if "option" in additional_kwargs or "o" in additional_kwargs:
        msg = "\n" * 2
        msg += "=" * 20
        msg += (
            "\n\033[93m\033[1m\033[4mThe --option Argument is deprecated!"
            + "\033[0m\n"
            + "You can now pass additional options via --key value "
            + "pairs. This example is equivalent to the old --option `key: "
            + "value`.\n"
            + "Passing --option `sth` is still possible, but will result"
            + " in a config entry `config[option] = sth`.\n"
        )
        msg += "=" * 20
        msg += "\n" * 2
        print(msg)
    main(opt, additional_kwargs)
