#!python


# imports
from getdist import plots, MCSamples
import matplotlib.pyplot as plt
import numpy as np
import argparse
import emcee
import yaml
import h5py
import os

# simplifiedmc imports
from simplifiedmc.shared import corner
from simplifiedmc.stan import getflatsamples


# main
def main(args):
    """Overlay the samples of one or more runs in a single corner plot.

    Each input folder must contain a 'chain.h5' file and an 'archive'
    directory. The kind of run is detected from the archive contents:
    'archive/model.stan' marks an ez-stan run and 'archive/model.py' marks an
    ez-emcee run.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments. The attributes read are ``input``
        (list of folders), ``markers``, ``labels`` and ``legend`` (strings
        holding Python literals, or None), ``contour_alpha``,
        ``filled_alpha``, ``output`` and ``no_show``.

    Raises
    ------
    ValueError
        If no input folder is provided.
    Exception
        If the legend length does not match the number of input folders, or
        if a folder is neither an ez-stan nor an ez-emcee output.
    """
    # get arguments
    folders = args.input
    # NOTE(review): eval() executes whatever expression is passed on the
    # command line. That is acceptable for a local CLI driven by its own user,
    # but must not be exposed to untrusted input; consider ast.literal_eval if
    # only plain literals need to be supported.
    markers = eval(args.markers) if args.markers else {}
    labels = eval(args.labels) if args.labels else None
    legend = eval(args.legend) if args.legend else None
    contour_alpha = args.contour_alpha
    filled_alpha = args.filled_alpha
    output = args.output
    noshow = args.no_show

    # without at least one folder there is nothing to plot, and the parameter
    # names (read from the runs below) would never be defined
    if not folders:
        raise ValueError("No input folders were provided, at least one is required")

    # check if sizes match
    if legend and len(legend) != len(folders):
        raise Exception(f"Missmatch between the length of input folders (len = {len(folders)}) and the number of labels for the legend provided (len = {len(legend)})")

    # list to hold all mcsamples for plotting
    mcsamples = []

    # iterate over all input folders
    for index, folder in enumerate(folders):
        # remove trailing "/" on folder
        if folder.endswith("/"):
            folder = folder[:-1]

        # use the provided legend entry, otherwise fall back to the folder name
        label = legend[index] if legend else folder.split("/")[-1]

        # open chain.h5, the context manager guarantees the handle is closed
        # once this run has been converted, even if an error is raised
        with h5py.File(f"{folder}/chain.h5", "r") as h5file:
            # if output is from ez-stan
            if os.path.isfile(f"{folder}/archive/model.stan"):
                # get arguments from config.yml
                with open(f"{folder}/archive/config.yml", "r") as file:
                    yml_loaded = yaml.full_load(file)
                names = yml_loaded["names"]
                # labels from the command line take precedence, otherwise the
                # first labels found in a config file are kept for all runs
                labels = labels if labels else yml_loaded["labels"]
                samples = yml_loaded["samples"]
                warmup = yml_loaded["warmup"]
                chains = yml_loaded["chains"]

                # get the samples and flatten them (drop chain information)
                flatsamples = getflatsamples(samples, warmup, chains, len(names), h5file["chain"])

            # if output is from ez-emcee
            elif os.path.isfile(f"{folder}/archive/model.py"):
                # get arguments from config.yml
                with open(f"{folder}/archive/config.yml", "r") as file:
                    yml_loaded = yaml.full_load(file)
                names = yml_loaded["names"]
                labels = labels if labels else yml_loaded["labels"]

                # get number of discarded steps from run.log
                with open(f"{folder}/archive/run.log", "r") as file:
                    discard = yaml.full_load(file)["discard"]

                # get the sampler using an emcee backend
                sampler = emcee.backends.HDFBackend(f"{folder}/chain.h5", read_only=True)

                # get the samples from the sampler
                flatsamples = sampler.get_chain(discard=discard, flat=True)

            # output is unknown
            else:
                raise Exception(f"Could not tell if output in {folder} is from ez-emcee or ez-stan, have you modified or deleted any files?")

            # convert to an MCSamples object while the file is still open, in
            # case the samples still reference data stored in it
            mcsamples.append(MCSamples(samples=flatsamples, names=names, labels=labels, label=label))

    # set marker to None if not provided (uses the names of the last run, all
    # runs are expected to share the same parameter names)
    for name in names:
        markers.setdefault(name, None)

    # show and/or output corner plot
    corner(mcsamples, markers, output=output, noshow=noshow, filled_alpha=filled_alpha, contour_alpha=contour_alpha)


if __name__ == "__main__":
    # create parser, the default help is disabled so that it can be added to
    # its own group further down
    parser = argparse.ArgumentParser(description = "A CLI to overlap the likelihoods produced by the samplers available on this package, using the getdist Python package.", add_help=False, epilog="Documentation, bug reports, suggestions and discussions at:\nhttps://github.com/jpmvferreira/simplifiedmc")

    # create argparser subgroups
    # NOTE: this relies on a private argparse attribute to remove the last of
    # the default argument groups, so that only the groups below are listed
    parser._action_groups.pop()
    required_group = parser.add_argument_group("Required arguments")
    config_group = parser.add_argument_group("Configuration arguments")
    output_group = parser.add_argument_group("Output arguments")
    help_group = parser.add_argument_group("Help dialog")

    # required arguments
    # nargs="+" makes argparse reject "-i" given without any folder
    required_group.add_argument("-i", "--input", nargs="+", help="The input folder(s) that contain a 'chain.h5' file. Assumes that all runs share the same parameter space, with the same names.", required=True)

    # configuration arguments
    config_group.add_argument("-l", "--labels", type=str, help="A string with a Python like list with the labels for each parameter, e.g.: \"['\\alpha', '\\beta']\".")
    config_group.add_argument("-m", "--markers", type=str, help="String with a Python style dictionary with the line markers to show rendered in the plots, e.g.: \"{'a': 0.5, 'b': 1.2}\". Defaults to no markers.")
    config_group.add_argument("--legend", type=str, help="A string with a Python like list with the legend for each input folder, e.g.: \"['\\sample 0', '\\sample 1']\". Must match the order of the input files. Defaults to folder name.")
    config_group.add_argument("--filled-alpha", type=float, help="The filled region alpha setting.")
    config_group.add_argument("--contour-alpha", type=float, help="The contours alpha setting.", default=0.5)

    # output arguments
    output_group.add_argument("-o", "--output", type=str, help="Output file for the corner plot.")
    output_group.add_argument("-ns", "--no-show", action="store_true", help="Don't show corner plot on screen.")

    # add help to its own subsection
    help_group.add_argument("-h", "--help", action="help", default=argparse.SUPPRESS, help="Show this help message and exit.")

    # get arguments
    args = parser.parse_args()

    # call main with the provided arguments
    main(args)
