#!python

# imports
from datetime import timedelta
from getdist import MCSamples
import numpy as np
import arviz as az
import argparse
import pandas
import stan
import h5py
import time
import os

# simplifiedmc imports
from simplifiedmc.stan import load, save, getsteps, getflatsamples, timeseries, runlog
from simplifiedmc.shared import corner, syslog, CIs


# main
def main(args):
    """
    Run the whole pipeline: configure, sample with Stan, plot and report.

    The configuration is obtained from `load`, the model is read from the
    .stan file and the input data from the .csv file(s). After sampling, the
    time series and corner plots are produced. If an output folder was
    provided the results are saved to disk, otherwise they are printed to
    the screen.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments, forwarded as-is to `load`.
    """
    # local import: only required here, to archive the model and data files
    import shutil

    print("[*] Fetching arguments")
    # get the configuration from the configuration .yml file and/or CLI
    model, data, yml, names, labels, initial, markers, samples, warmup, chains, output, savechain, gzip, lzf, noshow, init, ndim = load(args)

    # prepare the output folder
    if output:
        print("\n[*] Setting up the output folder")
        # drop trailing slashes so that the paths built below never contain "//"
        # (a path made only of slashes, e.g. "/", is kept untouched)
        output = output.rstrip("/") or output
        # create the folders without spawning a shell, so that paths with spaces or shell
        # metacharacters are handled correctly; reusing an existing folder is not an error
        for folder in (f"{output}/plot", f"{output}/archive/data"):
            os.makedirs(folder, exist_ok=True)

    # import the model from the model .stan file
    print("\n[*] Importing the model")
    with open(model, "r") as file:
        program = file.read()

    # get input data from the provided .csv file(s)
    print("\n[*] Importing input data")
    dic = getdata(data)

    # run the sampler
    print("\n[*] Running stan")
    timestart = time.time()
    posterior = stan.build(program, data=dic)
    fit = posterior.sample(num_chains=chains, num_samples=samples, num_warmup=warmup, init=init, save_warmup=True)
    timeend = time.time()

    # compute execution time
    timeelapsed = timedelta(seconds=round(timeend - timestart))

    # convert fit to a numpy array of size [steps, chains, ndim], with all of the computed steps
    totalsteps = getsteps(fit, names, samples, warmup, chains, ndim)

    # flatten total steps (remove chain information) and remove warmup to a numpy array of size [steps, ndim]
    # then convert it to an MCSamples object for plotting in the corner plot
    flatsamples = getflatsamples(samples, warmup, chains, ndim, totalsteps)
    mcsamples = MCSamples(samples=flatsamples, names=names, labels=labels)

    print("\n[*] Plotting routines")

    # plot the time series
    outputtimeseries = f"{output}/plot/time-series.png" if output else None
    timeseries(totalsteps, names, labels, markers, samples, warmup, chains, ndim, output=outputtimeseries, noshow=noshow)

    # corner plot
    outputcorner = f"{output}/plot/corner.png" if output else None
    corner(mcsamples, markers, output=outputcorner, noshow=noshow)

    # send relevant information to output folder
    if output:
        print("\n[*] Saving information to output folder")

        # save configuration arguments to file
        yml = f"{output}/archive/config.yml"
        outputyml = f"{output}/archive/output.yml"
        save(yml, names, labels, initial, markers, samples, warmup, chains, outputyml, output, savechain, gzip, lzf, noshow)

        # save the model file
        shutil.copy(model, f"{output}/archive/model.stan")

        # save the data file(s), keeping their original file names
        for file in data:
            shutil.copy(file, f"{output}/archive/data")

        # save the chain
        if savechain:
            # choose the compression options, GZIP takes precedence over LZF
            if gzip:
                compression = {"compression": "gzip", "compression_opts": gzip}
            elif lzf:
                compression = {"compression": "lzf"}
            else:
                compression = {}

            with h5py.File(f"{output}/chain.h5", "w") as file:
                file.create_dataset("chain", data=totalsteps, dtype="f", **compression)

        # save arviz summary
        summary = az.summary(fit)
        summary.to_csv(path_or_buf=f"{output}/archive/summary.csv")

        # save 1 and 2 sigma regions in a latex table
        CIs(mcsamples, file=f"{output}/CIs.tex")

        # save run information
        runlog(timeelapsed, file=f"{output}/archive/run.log")

        # save machine information
        syslog(file=f"{output}/archive/sys.log")

    # otherwise print relevant content to screen
    else:
        print("\n[*] Confidence intervals")
        CIs(mcsamples)

        print("\n[*] Short sampler summary")
        summary = az.summary(fit, var_names=names)
        print(summary[["mean", "sd", "ess_bulk", "ess_tail", "r_hat"]])

        print("\n[*] Run information")
        runlog(timeelapsed)


# data import
def getdata(files):
    """
    Build the data dictionary given to Stan from one or more .csv files.

    Every column is stored as a numpy array under its column name. Files are
    grouped by header: the k-th distinct header found (counting from 1) gets
    an entry "N{k}" with the total number of rows read from all the files
    that have exactly that header. Columns which show up in more than one
    file are concatenated, following the order of the files.

    Parameters
    ----------
    files : iterable
        Paths (or anything else accepted by `pandas.read_csv`) of the input
        .csv files. Content following a "#" is treated as a comment.

    Returns
    -------
    dict
        Maps each column name to a numpy array and each "N{k}" to an int.
    """
    dic = {}
    headers = []  # headers[k-1] is the header which defines group k

    for file in files:
        # parse each file only once, the header is taken from the parsed table
        columns = pandas.read_csv(file, comment="#")
        header = columns.columns.tolist()
        rows = len(columns)

        # files with a header that was already seen are stacked in the same group
        if header in headers:
            group = headers.index(header) + 1
            dic[f"N{group}"] += rows
        else:
            headers.append(header)
            dic[f"N{len(headers)}"] = rows

        for var in header:
            values = np.array(columns[var])
            dic[var] = np.append(dic[var], values) if var in dic else values

    return dic


# command line interface
def getparser():
    """
    Create the command line argument parser.

    The arguments are organized in custom help sections (required,
    configuration file, configuration arguments, output arguments and help
    dialog).

    Returns
    -------
    argparse.ArgumentParser
        Parser whose `parse_args` result is meant to be handed to `main`.
    """
    # create parser
    parser = argparse.ArgumentParser(description="A CLI that simplifies the usage of the No-U-Turn sampler (NUTS), a variant of the Hamiltonian Monte Carlo (HMC), implemented in the Stan programming language.", add_help=False, epilog="Documentation, bug reports, suggestions and discussions at:\nhttps://github.com/jpmvferreira/simplifiedmc")

    # create argparser subgroups
    # the default group for optional arguments is dropped, so that the help dialog only shows the custom sections
    parser._action_groups.pop()
    required = parser.add_argument_group("Required arguments")
    configfile = parser.add_argument_group("Configuration file")
    config = parser.add_argument_group("Configuration arguments")
    outputgroup = parser.add_argument_group("Output arguments")
    helpgroup = parser.add_argument_group("Help dialog")

    # required arguments
    required.add_argument("-m", "--model", type=str, help="Input .stan statistical model.", required=True)
    required.add_argument("-d", "--data", nargs="*", help="Input data from one (or more) .csv file(s).", required=True)

    # configuration file
    configfile.add_argument("-y", "--yml", type=str, help="YAML file to configure the program/sampler behaviour. The arguments available in the next section, if provided, will overwrite the configuration options provided by this file.")

    # configuration arguments
    config.add_argument("-n", "--names", type=str, help="String with a Python like list with the names for each parameter, e.g.: \"['a', 'b']\". Must match the names defined in the .stan model file.")
    config.add_argument("-l", "--labels", type=str, help="A string with a Python like list with the labels for each parameter, e.g.: \"['\\alpha', '\\beta']\". Defaults to names.")
    # the example is written with curly brackets, as required by a Python dictionary
    config.add_argument("-i", "--initial", type=str, help="String with a Python style dictionary with the initial condition for each parameter, e.g.: \"{'a': 'gauss(0, 1)', 'b': 'uniform(0, 1)'}\".")
    config.add_argument("--markers", type=str, help="String with a Python style dictionary with the line markers to show rendered in the plots, e.g.: \"{'a': 0.5, 'b': 1.2}\". Defaults to none.")
    config.add_argument("-s", "--samples", type=int, help="Number of steps to sample the posterior distribution, after the warmup.")
    config.add_argument("-w", "--warmup", type=int, help="Number of steps to warmup each chain.")
    config.add_argument("-c", "--chains", type=int, help="Number of chains to run which will either run in parallel or sequentially, based on the number of available threads. Defaults to all available hardware threads.")

    # output arguments
    outputgroup.add_argument("-o", "--output", type=str, help="Output folder to save the results. Will show less information on screen, which will be saved to disk. Warning: will overwrite existing files.")
    outputgroup.add_argument("-sc", "--save-chain", action="store_true", help="Saves the samples to disk.")
    outputgroup.add_argument("-g", "--gzip", nargs="*", help="Compress the samples with GZIP. Optionally specify the compression level with an integer from 0 (fastest) to 9 (slowest). Default is 4. Good compression, moderate speed.")
    outputgroup.add_argument("--lzf", action="store_true", help="Compress the samples with LZF. Low to moderate compression, very fast.")
    outputgroup.add_argument("-ns", "--no-show", action="store_true", help="Don't show plots on screen.")

    # add help to its own subsection
    helpgroup.add_argument("-h", "--help", action="help", default=argparse.SUPPRESS, help="Show this help message and exit.")

    return parser


# run if called
if __name__ == "__main__":
    # get arguments
    args = getparser().parse_args()

    # call main with the provided arguments
    main(args)
