#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import argparse
import ast
from pathlib import Path
import pandas as pd
from io import StringIO

import seaborn as sns
import matplotlib.pyplot as plt

from GANDLF.cli import copyrightMessage


# Column layout of the per-run segmentation log rows that get aggregated; the
# "Epoch" column is dropped again before the statistics are saved.
_SEGMENTATION_HEADER = (
    "Epoch,Train_Loss,Train_Dice,Val_Loss,Val_Dice,Testing_Loss,Testing_Dice"
)


def _read_last_row(log_file):
    """Return the last non-blank line of ``log_file`` without its line ending.

    Returns ``None`` when the file holds no non-blank line, so callers never
    pick up a stale row left over from a previously read file.
    """
    last_row = None
    with open(log_file) as handle:
        for line in handle:
            if line.strip():
                last_row = line.rstrip("\r\n")
    return last_row


def _collect_segmentation_rows(input_dir):
    """Collect the final row of every ``trainingScores_log`` file.

    Walks ``input_dir/<*testing_*>/<validation dir>/`` and returns the last
    non-blank line of each file whose name contains ``trainingScores_log``.
    Directory listings are sorted so the resulting row order is reproducible.
    """
    rows = []
    for testing_name in sorted(os.listdir(input_dir)):
        testing_dir = os.path.join(input_dir, testing_name)
        # only directories that are part of the testing structure are visited
        if "testing_" not in testing_name or not os.path.isdir(testing_dir):
            continue
        for validation_name in sorted(os.listdir(testing_dir)):
            validation_dir = os.path.join(testing_dir, validation_name)
            if not os.path.isdir(validation_dir):
                continue
            for file_name in sorted(os.listdir(validation_dir)):
                if "trainingScores_log" not in file_name:
                    continue
                log_file = os.path.join(validation_dir, file_name)
                last_row = _read_last_row(log_file)
                if last_row is None:
                    print("Skipping empty log file: {}".format(log_file))
                    continue
                rows.append(last_row)
    return rows


def _style_epoch_axis(axis, epochs, ylabel, title, unit_ylim):
    """Apply the shared x-limits, optional (0, 1) y-limits, labels and title."""
    axis.set(xlim=(0, epochs - 1))
    if unit_ylim:
        # accuracy panels are pinned to (0, 1) for proper visualization
        axis.set(ylim=(0, 1))
    axis.set(xlabel="Epoch", ylabel=ylabel, title=title)


def _plot_classification_combined(df_training, df_validation, output_plot):
    """Overlay training and validation curves: one accuracy and one loss panel."""
    epochs = len(df_training)
    fig, axes = plt.subplots(nrows=1, ncols=2)
    # ensure spacing between plots
    fig.subplots_adjust(wspace=0.5, hspace=0.5)

    panels = (
        ("balanced_accuracy", "Accuracy", "Accuracy Plot", True),
        ("loss", "Loss", "Loss Plot", False),
    )
    for axis, (metric, ylabel, title, unit_ylim) in zip(axes, panels):
        sns.lineplot(data=df_training, x="epoch_no", y="train_" + metric, ax=axis)
        sns.lineplot(data=df_validation, x="epoch_no", y="valid_" + metric, ax=axis)
        _style_epoch_axis(axis, epochs, ylabel, title, unit_ylim)
        axis.legend(labels=["Training", "Validation"])

    fig.savefig(output_plot, dpi=600)
    plt.close(fig)  # release the figure once it has been written


def _plot_classification_separate(df_training, df_validation, output_plot):
    """Draw training/validation accuracy and loss in a 2x2 grid of panels."""
    epochs = len(df_training)
    fig, axes = plt.subplots(nrows=2, ncols=2)
    fig.subplots_adjust(wspace=0.5, hspace=0.5)

    panels = (
        (axes[0, 0], df_training, "train_balanced_accuracy", "Accuracy",
         "Training Accuracy Plot", True),
        (axes[0, 1], df_validation, "valid_balanced_accuracy", "Accuracy",
         "Validation Accuracy Plot", True),
        (axes[1, 0], df_training, "train_loss", "Loss",
         "Training Loss Plot", False),
        (axes[1, 1], df_validation, "valid_loss", "Loss",
         "Validation Loss Plot", False),
    )
    for axis, frame, column, ylabel, title, unit_ylim in panels:
        sns.lineplot(data=frame, x="epoch_no", y=column, ax=axis)
        _style_epoch_axis(axis, epochs, ylabel, title, unit_ylim)

    fig.savefig(output_plot, dpi=600)
    plt.close(fig)  # release the figure once it has been written


def _plot_segmentation(data_full, output_plot):
    """Draw side-by-side box plots of the dice and loss columns per dataset."""
    datasets = ["Train", "Val", "Testing"]
    # split into dice-only and loss-only frames with plain dataset names
    data_dice = data_full[[name + "_Dice" for name in datasets]].copy()
    data_loss = data_full[[name + "_Loss" for name in datasets]].copy()
    data_dice.columns = datasets
    data_loss.columns = datasets

    fig, axes = plt.subplots(nrows=1, ncols=2, constrained_layout=True)
    panels = (
        (axes[0], data_dice, "Dice", "Dice plot"),
        (axes[1], data_loss, "Loss", "Loss plot"),
    )
    for axis, frame, ylabel, title in panels:
        bplot = sns.boxplot(data=frame, width=0.5, palette="colorblind", ax=axis)
        # set limits for y-axis for proper visualization
        bplot.set(ylim=(0, 1))
        bplot.set(xlabel="Dataset", ylabel=ylabel, title=title)
        # rotate so that everything is visible
        bplot.set_xticklabels(bplot.get_xticklabels(), rotation=15, ha="right")

    fig.savefig(output_plot, dpi=600)
    plt.close(fig)  # release the figure once it has been written


def main():
    """Command-line entry point: collect statistics and save plots.

    Reads ``logs_testing.csv`` from the model directory. When that file has no
    data rows the run is treated as a classification task and accuracy/loss
    curves are plotted from ``logs_training.csv`` and ``logs_validation.csv``
    (overlaid when ``--combinedplots`` is truthy). Otherwise it is treated as a
    segmentation task: the final row of every ``trainingScores_log`` file under
    the ``testing_*`` directories is gathered, saved to ``data.csv`` (without
    the ``Epoch`` column) and shown as dice/loss box plots. The plot is written
    to ``plot.png`` in the output directory.

    If ``logs_testing.csv`` does not exist, a message is printed and nothing
    else is produced.
    """
    parser = argparse.ArgumentParser(
        prog="GANDLF_CollectStats",
        formatter_class=argparse.RawTextHelpFormatter,
        description="Collect statistics from different testing/validation combinations from output directory.\n\n"
        + copyrightMessage,
    )
    # both directories are mandatory: without them there is nothing to read
    # from or write to, and os.path.normpath(None) would raise a TypeError
    parser.add_argument(
        "-m",
        "--modeldir",
        metavar="",
        type=str,
        required=True,
        help="Input directory which contains testing and validation models",
    )
    parser.add_argument(
        "-o",
        "--outputdir",
        metavar="",
        type=str,
        required=True,
        help="Output directory to save stats and plot",
    )
    # the value is parsed as a Python literal, e.g. "True" or "False"
    parser.add_argument(
        "-c",
        "--combinedplots",
        metavar="",
        default=False,
        type=ast.literal_eval,
        help="Overlays training and validation plots for both accuracy and loss (classification only).",
    )

    args = parser.parse_args()

    input_dir = os.path.normpath(args.modeldir)
    output_dir = os.path.normpath(args.outputdir)
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    output_file = os.path.join(output_dir, "data.csv")  # data file name
    output_plot = os.path.join(output_dir, "plot.png")  # plot file

    training_logs = os.path.join(input_dir, "logs_training.csv")
    validation_logs = os.path.join(input_dir, "logs_validation.csv")
    testing_logs = os.path.join(input_dir, "logs_testing.csv")

    if not os.path.exists(testing_logs):
        # previously this case ended silently; tell the user why nothing happened
        print("Testing logs not found at '{}', nothing to do.".format(testing_logs))
        return

    testing_logs_csv = pd.read_csv(testing_logs)

    # a testing log without any data rows marks a classification task
    if len(testing_logs_csv) == 0:
        print("Classification task detected, generating accuracy and loss plots.")

        df_training = pd.read_csv(training_logs)
        df_validation = pd.read_csv(validation_logs)

        # check whether user wants training + validation overlaid plots
        if args.combinedplots:
            _plot_classification_combined(df_training, df_validation, output_plot)
        else:
            _plot_classification_separate(df_training, df_validation, output_plot)

        print("Plots saved successfully.")
        return

    print("Segmentation task detected, generating dice and loss plots.")

    rows = _collect_segmentation_rows(input_dir)
    # every row is terminated explicitly so rows from different files can
    # never run together, whatever line ending the source log used
    final_stats = "\n".join([_SEGMENTATION_HEADER] + rows) + "\n"
    data_full = pd.read_csv(StringIO(final_stats), sep=",")
    del data_full["Epoch"]  # no need for epoch
    data_full.to_csv(output_file, index=False)  # save updated data

    if data_full.empty:
        print("No 'trainingScores_log' entries were found, skipping plots.")
        return

    _plot_segmentation(data_full, output_plot)

    print("Plots saved successfully.")


# Script entry point: run the command-line interface only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
