#!python

from __future__ import unicode_literals

import json
import os
from os.path import basename
import glob

import matplotlib
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
from networkx.readwrite import json_graph
from pandas.plotting import autocorrelation_plot

import trilearn.auxiliary_functions as aux
import trilearn.graph.decomposable
import trilearn.graph.graph as glib
import trilearn.graph.junction_tree as jtlib
from trilearn.distributions import sequential_junction_tree_distributions as sjtd
from trilearn import pgibbs
from trilearn.graph import trajectory

# Render figure text (titles, tick labels) through LaTeX; matplotlib needs a
# working LaTeX installation for this.
matplotlib.rcParams['text.usetex'] = True
# NOTE(review): 'text.latex.unicode' is deprecated in newer matplotlib releases
# and may be rejected there (KeyError) — confirm against the installed version.
matplotlib.rcParams['text.latex.unicode'] = True
# pyplot is imported only after the rcParams above have been set.
import matplotlib.pyplot as plt
# Use small (size 8) tick labels on both axes of every figure.
matplotlib.rc('xtick', labelsize=8)
matplotlib.rc('ytick', labelsize=8)


# Print numpy floats with two digits of precision.
np.set_printoptions(precision=2)


def _group_trajectories(input_directory):
    """Read every ``*.json`` trajectory in *input_directory*, grouped by ``str(t)``.

    ``str(t)`` is used as the parameter-setting key, so runs that share a
    setting end up in the same plot.

    Returns:
        dict mapping ``str(t)`` to a list of ``trajectory.Trajectory`` objects.
        Files are read in sorted name order, so the grouping and the overlay
        order in the plots are reproducible (``glob`` order is arbitrary).
    """
    trajectories = {}
    for json_file in sorted(glob.glob(os.path.join(input_directory, "*.json"))):
        t = trajectory.Trajectory()
        print(json_file)
        t.read_file(json_file)
        trajectories.setdefault(str(t), []).append(t)
    return trajectories


def _save_and_clear(output_directory, param_setting, suffix):
    """Save the current pyplot figure as ``<param_setting><suffix>`` and clear it."""
    # os.path.join works whether or not output_directory ends with a separator.
    plt.savefig(os.path.join(output_directory, param_setting + suffix))
    plt.clf()


def _plot_setting(param_setting, traj_list, burnin_end, output_directory):
    """Write size, log-likelihood, heatmap and size-autocorrelation plots for one setting.

    Args:
        param_setting: group key (``str(t)``) used as the output file-name prefix.
        traj_list: non-empty list of trajectories sharing ``param_setting``.
        burnin_end: index where the burn-in period ends; passed to the
            trajectory accessors.
        output_directory: directory the PNG files are written to.
    """
    first = traj_list[0]
    title = str(first.sampling_method)
    print(first.sampling_method)
    print("Average sample time (for the first trajectory): " + str(np.mean(first.time)))

    # Graph-size traces of every trajectory in the group, overlaid.
    for t in traj_list:
        t.size(burnin_end).plot()
    plt.title(title)
    _save_and_clear(output_directory, param_setting, "_size.png")

    # Log-likelihood traces of every trajectory in the group, overlaid.
    for t in traj_list:
        t.log_likelihood(burnin_end).plot()
    plt.title(title)
    _save_and_clear(output_directory, param_setting, "_log-likelihood.png")

    # The heatmap and autocorrelation plots use only the first trajectory.
    aux.plot_heatmap(first.empirical_distribution(burnin_end).heatmap())
    plt.title(title)
    _save_and_clear(output_directory, param_setting, "_heatmap.png")

    autocorrelation_plot(first.size(burnin_end))
    plt.title(title)
    _save_and_clear(output_directory, param_setting, "_size_autocorr.png")


def main(data_filename, n_particles, trajectory_length, radius, alphas, betas,
         graphfile, precmat, burnin_end, input_directory, output_directory, reps):
    """Generate diagnostic plots for every group of saved trajectories.

    Every ``*.json`` trajectory in ``input_directory`` is read, trajectories
    are grouped by ``str(t)``, and four PNG files are written per group into
    ``output_directory``: ``<key>_size.png``, ``<key>_log-likelihood.png``,
    ``<key>_heatmap.png`` and ``<key>_size_autocorr.png``.

    Only ``input_directory``, ``output_directory`` and ``burnin_end`` affect
    the output. The remaining parameters are accepted so the command-line
    interface stays unchanged. The dataset in ``data_filename`` used to be
    loaded, but nothing derived from it was ever used, so it is no longer read.
    """
    # makedirs (unlike mkdir) also creates missing parent directories.
    if not os.path.isdir(output_directory):
        os.makedirs(output_directory)

    trajectories = _group_trajectories(input_directory)
    if not trajectories:
        print("No *.json trajectory files found in " + input_directory)
        return

    for param_setting in sorted(trajectories):
        print(param_setting + ": " + str(len(trajectories[param_setting])) + " trajectories")

    for param_setting in sorted(trajectories):
        _plot_setting(param_setting, trajectories[param_setting], burnin_end, output_directory)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser("Generates analytics for the Markov chain of decomposable graphs generated "
                                     "by particle Gibbs.")

    parser.add_argument(
        '-f', '--data_filename',
        required=True,
        help="Filename of dataset")
    parser.add_argument(
        '-N', '--n_particles',
        type=int, required=True, nargs='+',
        help="Number of SMC particles")
    parser.add_argument(
        '-a', '--alphas', default=[0.5],
        type=float, required=False, nargs='+',
        help="Parameter for the junction tree expander")
    parser.add_argument(
        '-b', '--betas', default=[0.5],
        type=float, required=False, nargs='+',
        help="Parameter for the junction tree expander")
    parser.add_argument(
        '-M', '--trajectory_length',
        type=int, required=True, nargs='+',
        help="Number of Gibbs samples")
    parser.add_argument(
        '-g', '--graphfile',
        required=False,
        help="The true graph in json-format")
    parser.add_argument(
        '-p', '--precmat',
        required=False,
        help="The true precision matrix")
    parser.add_argument(
        '-r', '--radius',
        type=int, required=False, default=None, nargs='+',
        help="The search neighborhood radius for the SMC sampler")
    parser.add_argument(
        '-o', '--output_directory',
        required=False, default="./",
        help="Output directory")
    parser.add_argument(
        '-i', '--input_directory',
        required=False, default="./",
        help="Input directory")
    parser.add_argument(
        '--burnin_end', type=int, required=False, default=0,
        help="Burn-in period ends at this index")
    parser.add_argument(
        '--reps', type=int, required=False, default=0,
        help='Number of trajectories for each parameter setting')

    args = parser.parse_args()
    main(**args.__dict__)