#!python


# imports
from datetime import timedelta
from importlib import import_module
from multiprocessing import Pool
from random import uniform
import argparse
import os
import shutil
import sys
import tempfile
import time

from getdist import MCSamples
import emcee
import numpy as np
import pandas

# simplifiedmc imports
from simplifiedmc.emcee import load, save, autocorrelation, timeseries, runlog
from simplifiedmc.shared import corner, syslog, CIs


# avoid numpy parallelization because it conflicts with emcee parallelization
# NOTE(review): this assignment runs after numpy has already been imported at the
# top of the file; threading runtimes commonly read OMP_NUM_THREADS when they are
# first loaded, so confirm the limit still takes effect here (or move the
# assignment above the numpy import)
os.environ["OMP_NUM_THREADS"] = "1"


# main
def main(args):
    """Run emcee for the given command line arguments and report the results.

    The configuration is obtained from ``load(args)``. The model file is
    copied to a scratch folder and imported to get its ``ln_probability``
    function, the columns of every data file are passed to it as extra
    positional arguments, and the sampler runs until either ``maxsteps`` is
    reached or the mean autocorrelation time has stayed within ``percentage``
    of its latest value for more than ``samples`` steps. Afterwards the
    autocorrelation, (optional) time series and corner plots are produced and
    the run is either archived in the output folder or summarised on screen.

    Parameters
    ----------
    args
        Parsed command line arguments, forwarded unchanged to ``load``.

    Raises
    ------
    NotADirectoryError
        If the requested output folder exists and is a regular file.
    """
    # get arguments from CLI and file
    print("[*] Fetching arguments")
    # the time series flag is unpacked as `showtimeseries` so that it does not
    # shadow the imported `timeseries` plotting function called further down
    (model, data, yml, names, labels, initial, markers, percentage, samples,
     check, maxsteps, walkers, processes, output, savechain, gzip, lzf, tmp,
     shm, thin, showtimeseries, noshow, noprogress, init, ndim) = load(args)

    # set up the output folder
    if output:
        print("\n[*] Setting up the output folder")
        if len(output) > 1 and output.endswith("/"):
            output = output[:-1]

        if os.path.isfile(output):
            raise NotADirectoryError("Output directory is a file")

        # exist_ok makes re-using a previous output folder silent
        for folder in (output, f"{output}/archive", f"{output}/archive/data", f"{output}/plot"):
            os.makedirs(folder, exist_ok=True)

    # prepare a temporary folder to import module
    # mkdtemp picks a unique name and creates the folder in a single step
    scratch = tempfile.mkdtemp(prefix="ez-emcee-", dir="/tmp")
    shmscratch = None

    # import model
    print("\n[*] Importing the model")
    ln_probability = _import_ln_probability(model, scratch)

    # get input data
    print("\n[*] Importing input data")
    arguments = _read_data(data)

    # set up the backend for emcee
    print("\n[*] Setting up the backend for emcee")
    chainpath = None
    if savechain:
        if tmp:
            chainpath = os.path.join(scratch, "chain.h5")
        elif shm:
            shmscratch = tempfile.mkdtemp(prefix="ez-emcee-", dir="/dev/shm")
            chainpath = os.path.join(shmscratch, "chain.h5")
        else:
            chainpath = f"{output}/chain.h5"

        # the compression options are honoured wherever the chain is written
        backend = _make_backend(chainpath, gzip, lzf)
        backend.reset(walkers, ndim)

    else:
        backend = None

    # auxiliary variables to compute the autocorrelation time in emcee
    # delta and converged get defaults because they are reported even when
    # the run is too short for a single convergence check
    index = 0
    correlation = np.empty(maxsteps)
    delta = 0.0
    converged = 0

    # run emcee
    print("\n[*] Running emcee")
    timestart = time.time()

    with Pool(processes=processes) as pool:
        sampler = emcee.EnsembleSampler(walkers, ndim, ln_probability, args=arguments, pool=pool, backend=backend)

        # sample up to maxsteps
        for _ in sampler.sample(init, iterations=maxsteps, progress=noprogress):
            # only look at the autocorrelation time every `check` steps
            if sampler.iteration % check:
                continue

            # check for autocorrelation time
            tau = sampler.get_autocorr_time(tol=0) # tol=0 gives an estimate even if it isnt trustworthy
            correlation[index] = np.mean(tau)  # average across all dimensions
            index += 1

            # check for convergence
            delta = correlation[index-1] * percentage
            converged = _converged_steps(correlation, index, delta, check)
            if converged > samples:
                break

    timeend = time.time()

    # compute execution time
    timeelapsed = timedelta(seconds=round(timeend - timestart))

    # compute the number of discarded samples
    # a run shorter than `samples` keeps every step instead of asking for a negative discard
    discard = max(sampler.iteration - samples, 0)

    # get samples in different formats
    if thin and index:
        # a stride below one step (or a non finite estimate) cannot be used
        # to slice the chain, fall back to keeping every step
        stride = thin * correlation[index-1]
        stride = max(int(stride), 1) if np.isfinite(stride) else 1
        flatsamples = sampler.get_chain(discard=discard, flat=True, thin=stride)
    else:
        flatsamples = sampler.get_chain(discard=discard, flat=True)
    mcsamples = MCSamples(samples=flatsamples, names=names, labels=labels)

    print("\n[*] Plotting routines")

    # autocorrelation plot
    outputautocorrelation = f"{output}/plot/autocorrelation.png" if output else None
    autocorrelation(correlation, samples, check, index, sampler.iteration, delta, output=outputautocorrelation, noshow=noshow)

    # plot time series
    if showtimeseries:
        # the full chain is only fetched when it is actually plotted
        steps = sampler.get_chain()
        outputtimeseries = f"{output}/plot/time-series.png" if output else None
        timeseries(steps, labels, ndim, discard, output=outputtimeseries, noshow=noshow)

    # corner plot
    outputcorner = f"{output}/plot/corner.png" if output else None
    corner(mcsamples, markers, output=outputcorner, noshow=noshow)

    # send relevant information to output folder
    if output:
        print("\n[*] Saving information to output folder")

        # save configuration and output arguments to files
        yml = f"{output}/archive/config.yml"
        outputyml = f"{output}/archive/output.yml"
        save(yml, names, labels, initial, markers, percentage, samples, check, maxsteps, walkers, processes, outputyml, output, savechain, gzip, lzf, tmp, shm, thin, showtimeseries, noshow, noprogress)

        # save the model file
        shutil.copy(model, f"{output}/archive/model.py")

        # save the data file(s)
        for file in data:
            shutil.copy(file, f"{output}/archive/data")

        # move chain.h5 from /tmp or /dev/shm to output, if present
        finalchain = f"{output}/chain.h5"
        if chainpath is not None and chainpath != finalchain:
            shutil.move(chainpath, finalchain)

        # save 1 and 2 sigma regions in a latex table
        CIs(mcsamples, file=f"{output}/CIs.tex")

        # save machine information
        syslog(file=f"{output}/archive/sys.log")

        # save run information
        runlog(timeelapsed, samples, discard, converged, file=f"{output}/archive/run.log")

    # otherwise print relevant content to screen
    else:
        print("\n[*] Confidence intervals")
        CIs(mcsamples)

        print("\n[*] Run information")
        runlog(timeelapsed, samples, discard, converged)

    # clear temporary folder(s)
    shutil.rmtree(scratch, ignore_errors=True)
    if shmscratch is not None:
        shutil.rmtree(shmscratch, ignore_errors=True)


def _import_ln_probability(model, folder):
    """Copy the model file into `folder` as model.py, import it and return its ln_probability."""
    shutil.copy(model, os.path.join(folder, "model.py"))

    # the folder goes first so an unrelated `model` module elsewhere on the
    # path is not picked up instead of the user provided file
    sys.path.insert(0, folder)
    try:
        return import_module("model").ln_probability
    finally:
        if folder in sys.path:
            sys.path.remove(folder)


def _read_data(data):
    """Return a tuple with every column of every .csv file in `data`, in file and column order."""
    arguments = ()
    for file in data:
        columns = pandas.read_csv(file, comment="#")
        arguments += tuple(columns[name] for name in columns.columns)
    return arguments


def _make_backend(chainpath, gzip, lzf):
    """Create the HDF backend at `chainpath` using the requested compression, if any."""
    if gzip:
        return emcee.backends.HDFBackend(chainpath, compression="gzip", compression_opts=gzip)
    if lzf:
        return emcee.backends.HDFBackend(chainpath, compression="lzf")
    return emcee.backends.HDFBackend(chainpath)


def _converged_steps(correlation, index, delta, check):
    """Count the steps covered by consecutive previous estimates lying within `delta` of the latest one."""
    latest = correlation[index-1]
    steps = 0
    i = index - 2

    # the bound on i is tested first so that slots which were never written
    # (reached through a negative index) are not compared
    while i >= 0 and latest - delta < correlation[i] < latest + delta:
        steps += check
        i -= 1

    return steps


# run if called
if __name__ == "__main__":
    # create parser
    parser = argparse.ArgumentParser(description = "A CLI that simplifies the usage of the Affine Invariant MCMC Ensemble sampler, implemented in the emcee Python package.", add_help=False, epilog="Documentation, bug reports, suggestions and discussions at:\nhttps://github.com/jpmvferreira/simplifiedmc")

    # create argparser subgroups
    # the last default group is dropped so that only the custom groups below show up in the help dialog
    parser._action_groups.pop()
    required = parser.add_argument_group("Required arguments")
    configfile = parser.add_argument_group("Configuration file")
    config = parser.add_argument_group("Configuration arguments")
    output = parser.add_argument_group("Output arguments")
    # named helpgroup so the builtin help() is not shadowed
    helpgroup = parser.add_argument_group("Help dialog")

    # required arguments
    required.add_argument("-m", "--model", type=str, help="Python file that defines the model probability.", required=True)
    required.add_argument("-d", "--data", nargs="*", help="Input data from one (or more) .csv file(s).", required=True)

    # configuration file
    configfile.add_argument("-y", "--yml", type=str, help="YAML file to configure the program/sampler behaviour. The arguments available in the next section, if provided, will overwrite the configuration options provided by this file.")

    # configuration arguments
    config.add_argument("-n", "--names", type=str, help="String with a Python like list with the names for each parameter, e.g.: \"['a', 'b']\".")
    config.add_argument("-l", "--labels", type=str, help="A string with a Python like list with the labels for each parameter, e.g.: \"['\\alpha', '\\beta']\".")
    config.add_argument("-i", "--initial", type=str, help="String with a Python style dictionary with the initial condition for each parameter, e.g.: \"{'a': 'gauss(0, 1)', 'b': 'uniform(0, 1)'}\".")
    config.add_argument("--markers", type=str, help="String with a Python style dictionary with the line markers to show rendered in the plots, e.g.: \"{'a': 0.5, 'b': 1.2}\".")
    config.add_argument("-p", "--percentage", type=float, help="Autocorrelation time must change less than this percentage to consider convergence.")
    config.add_argument("-s", "--samples", type=int, help="Number of steps to compute when convergence is met.")
    config.add_argument("-c", "--check", type=int, help="Number of steps to check if convergence is met. Defaults to 1000 steps")
    config.add_argument("-M", "--maxsteps", type=int, help="The maximum number of steps, regardless of the convergence status. Defaults to 100000 steps.")
    config.add_argument("-w", "--walkers", type=int, help="The number of walkers to sample the likelihood. Defaults to 32 walkers.")
    config.add_argument("--processes", type=int, help="The number of processes to spawn to make use of multi core processors. Defaults to all available hardware threads.")

    # output arguments
    output.add_argument("-o", "--output", type=str, help="Output folder. Warning: will overwrite existing files.")
    output.add_argument("-sc", "--save-chain", action="store_true", help="Saves the computed steps to disk. Warning: IO heavy operation.")
    # nargs="?" makes the level optional, as the help text states; a bare -g uses the recommended level
    output.add_argument("-g", "--gzip", type=int, nargs="?", const=4, help="Compress the samples with GZIP. Optionally specify the compression level with an integer from 0 (fast) to 9 (slow). Recommended value is 4. Good compression, moderate speed.")
    output.add_argument("--lzf", action="store_true", help="Compress the samples with LZF. Low to moderate compression, very fast.")
    output.add_argument("--tmp", action="store_true", help="Temporarily store the computed steps in /tmp, which usually uses tmpfs, and writes it to disk when done. Warning: Steps will be lost on system failure and might require a lot of memory.")
    output.add_argument("--shm", action="store_true", help="Same as --tmp but for /dev/shm, for systems who do not have /tmp mounted as tmpfs.")
    output.add_argument("--thin", type=float, help="Thin samples by the closest integer number of steps corresponding to the user provided fraction multiplied by the autocorrelation time. This only affects the corner plot, not the information stored on disk.")
    output.add_argument("-t", "--time-series", action="store_true", help="Show or save the time series for each parameter. Warning: Memory intensive on longer runs.")
    output.add_argument("-ns", "--no-show", action="store_true", help="Don't show plots on screen.")
    # store_false: the parsed value is True unless the flag is given
    output.add_argument("-np", "--no-progress", action="store_false", help="Don't show the emcee progress bar.")

    # add help to its own subsection
    helpgroup.add_argument("-h", "--help", action="help", default=argparse.SUPPRESS, help="Show this help message and exit.")

    # get arguments
    args = parser.parse_args()

    # run main
    main(args)
