#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import print_function, division
import os
import argparse
import ast
import sys

from GANDLF import version
from GANDLF.cli import main_run, copyrightMessage


# Spellings (compared case-insensitively) accepted for boolean options.
_TRUE_STRINGS = frozenset(("true", "t", "yes", "y", "on", "1"))
_FALSE_STRINGS = frozenset(("false", "f", "no", "n", "off", "0"))


def parse_bool_flag(value):
    """Convert a command-line token into a ``bool``.

    Common spellings such as ``True``/``false``/``yes``/``0`` are matched
    case-insensitively. Any other token is evaluated with
    ``ast.literal_eval`` and reduced to its truthiness, so every value the
    options accepted before (when ``ast.literal_eval`` was used directly as
    the argparse ``type``) is still accepted.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        The boolean the token stands for.

    Raises:
        argparse.ArgumentTypeError: If the token is neither a known spelling
            nor a valid Python literal. argparse turns this into a regular
            usage error instead of a traceback.
    """
    if isinstance(value, bool):
        return value

    text = str(value).strip()
    lowered = text.lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False

    try:
        return bool(ast.literal_eval(text))
    except (ValueError, SyntaxError, TypeError):
        # argparse itself does not catch SyntaxError raised by a ``type``
        # callable, so translate every parsing failure into the exception it
        # does report cleanly.
        raise argparse.ArgumentTypeError(
            "expected a boolean such as True or False, got {!r}".format(value)
        )


def build_parser():
    """Create the argument parser for the GANDLF command-line entry point.

    Returns:
        argparse.ArgumentParser: Parser exposing the config, data, train,
        modeldir, device, reset, resume, outputdir and version options, plus
        the hidden ``--rawinput`` placeholder.
    """
    parser = argparse.ArgumentParser(
        prog="GANDLF",
        formatter_class=argparse.RawTextHelpFormatter,
        description="Semantic segmentation, regression, and classification for medical images using Deep Learning.\n\n"
        + copyrightMessage,
    )
    parser.add_argument(
        "-c",
        "--config",
        "--parameters_file",
        metavar="",
        type=str,
        required=True,
        help="The configuration file (contains all the information related to the training/inference session)",
    )
    parser.add_argument(
        "-i",
        "--inputdata",
        "--data_path",
        metavar="",
        type=str,
        required=True,
        help="Data CSV file that is used for training/inference; can also take comma-separated training-validation pre-split CSVs",
    )
    parser.add_argument(
        "-t",
        "--train",
        metavar="",
        type=parse_bool_flag,
        required=True,
        help="True: training and False: inference; for inference, there needs to be a compatible model saved in '--modeldir'",
    )
    parser.add_argument(
        "-m",
        "--modeldir",
        metavar="",
        type=str,
        help="Training: Output directory to save intermediate files and model weights; inference: location of previous training session output",
    )
    # Not marked as required: combining ``required=True`` with a default made
    # the "cuda" default unreachable.
    parser.add_argument(
        "-d",
        "--device",
        default="cuda",
        metavar="",
        type=str,
        help="Device to perform requested session on 'cpu' or 'cuda'; for cuda, ensure CUDA_VISIBLE_DEVICES env var is set",
    )
    parser.add_argument(
        "-rt",
        "--reset",
        metavar="",
        default=False,
        type=parse_bool_flag,
        help="Completely resets the previous run by deleting 'modeldir'",
    )
    parser.add_argument(
        "-rm",
        "--resume",
        metavar="",
        default=False,
        type=parse_bool_flag,
        help="Resume previous training by only keeping model dict in 'modeldir'",
    )
    parser.add_argument(
        "-o",
        "--outputdir",
        "--output_path",
        metavar="",
        type=str,
        help="Location to save the output of the inference session. Not used for training.",
    )
    parser.add_argument(
        "-v",
        "--version",
        action="version",
        version="%(prog)s v{}".format(version) + "\n\n" + copyrightMessage,
        help="Show program's version number and exit.",
    )

    # This is a dummy argument that exists to trigger MLCube mounting requirements.
    # Do not remove.
    parser.add_argument("-rawinput", "--rawinput", help=argparse.SUPPRESS)

    return parser


def main(argv=None):
    """Parse the command line, normalise the options and launch ``main_run``.

    Args:
        argv: Optional list of argument strings. When ``None`` the arguments
            are taken from ``sys.argv`` by argparse.

    Raises:
        SystemExit: On invalid arguments, when neither ``modeldir`` nor
            ``outputdir`` is given, when the configuration file does not
            exist, or when ``main_run`` raises an exception.
    """
    parser = build_parser()
    args = parser.parse_args(argv)

    # Fall back to the output directory when no model directory is given.
    if args.modeldir is None and args.outputdir:
        args.modeldir = args.outputdir

    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O`` and would otherwise surface as a bare traceback.
    if args.modeldir is None:
        parser.error("Missing required parameter: modeldir")

    if os.path.isdir(args.inputdata):
        # Is this a fine assumption to make?
        # Medperf models receive the data generated by the data preparator mlcube
        # We can therefore ensure the output of that mlcube contains a data.csv file
        args.inputdata = os.path.join(args.inputdata, "data.csv")

    if not args.train:
        # if inference mode, then no need to check for reset/resume
        args.reset, args.resume = False, False

    if args.reset and args.resume:
        print(
            "WARNING: 'reset' and 'resume' are mutually exclusive; 'resume' will be used."
        )
        args.reset = False

    # config file should always be present
    if not os.path.isfile(args.config):
        sys.exit("ERROR: Configuration file not found!")

    try:
        main_run(
            args.inputdata,
            args.config,
            args.modeldir,
            args.train,
            args.device,
            args.resume,
            args.reset,
            args.outputdir,
        )
    except Exception as e:
        # Top-level boundary: report the failure as a one-line error message
        # and a non-zero exit status.
        sys.exit("ERROR: " + str(e))

    print("Finished.")


if __name__ == "__main__":
    main()
