#!python

from __future__ import unicode_literals

import os
import glob

import numpy as np
from pandas.plotting import autocorrelation_plot

import trilearn.auxiliary_functions as aux
from trilearn.graph import trajectory

# matplotlib.rcParams['text.usetex'] = True
# matplotlib.rcParams['text.latex.unicode'] = True
import matplotlib.pyplot as plt
# matplotlib.rc('xtick', labelsize=8)
# matplotlib.rc('ytick', labelsize=8)
# np.set_printoptions(precision=2)


def _save_current_figure(output_directory, filename, title):
    """Title the current matplotlib figure, save it, then clear it.

    The target path is built with ``os.path.join`` so that
    ``output_directory`` works with or without a trailing separator.
    """
    plt.title(title)
    plt.savefig(os.path.join(output_directory, filename))
    plt.clf()


def main(burnin_end, input_directory, output_directory):
    """Plot diagnostics for every trajectory JSON file in a directory.

    Every ``*.json`` file directly inside ``input_directory`` is loaded with
    ``trajectory.Trajectory.read_file``.  Trajectories whose ``str()``
    representations are equal are treated as one group.  For each group the
    following PNG files are written to ``output_directory``:

    * ``<group>_size.png`` -- ``t.size(burnin_end)`` of all trajectories in
      the group, drawn into one figure.
    * ``<group>_log-likelihood.png`` -- ``t.log_likelihood(burnin_end)`` of
      all trajectories in the group, drawn into one figure.
    * ``<group>_heatmap_<i>.png`` -- one heatmap per trajectory.
    * ``<group>_size_autocorr_<i>.png`` -- one autocorrelation plot of
      ``t.size(burnin_end)`` per trajectory.

    Args:
        burnin_end: Value forwarded unchanged to ``size``,
            ``log_likelihood`` and ``empirical_distribution``.
        input_directory: Directory searched (non-recursively) for ``*.json``.
        output_directory: Directory receiving the PNG files; created,
            including missing parents, when it does not exist.  An empty
            string means the current working directory.
    """
    # makedirs (rather than mkdir) so nested output paths can be created.
    if output_directory and not os.path.isdir(output_directory):
        os.makedirs(output_directory)

    # Gather all trajectories with the same parameter setting (same str())
    # so that they end up in the same plot.
    trajectories = {}
    for filename in glob.glob(os.path.join(input_directory, "*.json")):
        t = trajectory.Trajectory()
        print(filename)
        t.read_file(filename)
        trajectories.setdefault(str(t), []).append(t)

    # The dict key equals str(t) for every member of the group, so it doubles
    # as the common file-name prefix.
    for param_setting, traj_list in trajectories.items():
        first, last = traj_list[0], traj_list[-1]
        print(first.sampling_method)
        print("Average sample time: " + str(np.mean(first.time)))

        # All size trajectories of the group share one figure.
        for t in traj_list:
            t.size(burnin_end).plot()
        _save_current_figure(output_directory,
                             param_setting + "_size.png",
                             str(last.sampling_method))

        # All log-likelihood trajectories of the group share one figure.
        for t in traj_list:
            t.log_likelihood(burnin_end).plot()
        _save_current_figure(output_directory,
                             param_setting + "_log-likelihood.png",
                             str(last.sampling_method))

        # One heatmap figure per trajectory.
        for i, t in enumerate(traj_list):
            aux.plot_heatmap(t.empirical_distribution(burnin_end).heatmap())
            _save_current_figure(output_directory,
                                 param_setting + "_heatmap_" + str(i) + ".png",
                                 str(first.sampling_method))

        # One autocorrelation figure per trajectory.
        for i, t in enumerate(traj_list):
            autocorrelation_plot(t.size(burnin_end))
            _save_current_figure(output_directory,
                                 param_setting + "_size_autocorr_" + str(i) + ".png",
                                 str(t.sampling_method))


if __name__ == "__main__":
    # Command-line entry point: parse the options and hand them to main().
    import argparse

    # The text must be passed as ``description``; the first positional
    # parameter of ArgumentParser is ``prog`` (the program name shown in the
    # usage line), not the description.
    parser = argparse.ArgumentParser(
        description="Generates analytics for the Markov chain of decomposable graphs generated "
                    "by particle Gibbs.")

    parser.add_argument(
        '-o', '--output_directory',
        required=False, default="./",
        help="Output directory")
    parser.add_argument(
        '-i', '--input_directory',
        required=False, default="./",
        help="Input directory")
    parser.add_argument(
        '--burnin_end', type=int, required=False, default=0,
        help="Burn-in period ends at this index")

    args = parser.parse_args()
    # The option destinations match main()'s parameter names one-to-one.
    main(**vars(args))