#!/usr/bin/env python3

import glob
import os
import re
import tqdm
import math
import numpy
import shutil
import tarfile
import argparse

import pandas            as pnd
import matplotlib.pyplot as plt
import utils_noroot      as utnr

from logzero import logger as log

#-------------------------------------------------------
class data:
    '''
    Namespace holding the run configuration shared by the functions of this script
    '''
    # All three settings start unset and are assigned by get_args()
    job_name = out_path = good_fit = None
#-------------------------------------------------------
def rename_jsn():
    '''
    Moves every JSON file found in ./result_jsn to <out_path>/output,
    the target directory is created if it does not exist
    '''
    target_dir = f'{data.out_path}/output'
    os.makedirs(target_dir, exist_ok=True)

    for source in glob.glob('result_jsn/*.json'):
        destination = f'{target_dir}/{os.path.basename(source)}'
        os.replace(source, destination)
#-------------------------------------------------------
def untar(tar_path):
    '''
    Extracts every member of a tarball into the current working directory

    Parameters
    ----------
    tar_path (str): Path to the tarball, the compression is detected by tarfile

    Raises
    ----------
    tarfile.ReadError: If the file cannot be read as a tar archive
    '''
    # The context manager closes the archive even if the extraction fails half way
    # NOTE(review): extractall trusts the member paths stored in the archive,
    # only use this on tarballs produced by trusted jobs
    with tarfile.open(tar_path) as tar:
        tar.extractall()
#-------------------------------------------------------
def make_json():
    '''
    Will take tarballs, untar them, and put all the JSON files in output directory

    Does nothing if no job name was set or if <out_path>/output already exists.
    Tarballs that cannot be read are skipped with a warning.

    Raises
    ----------
    RuntimeError: If no sandbox is found in out_path, or no sandbox holds a result_jsn.tar.gz
    '''

    if not data.job_name:
        return

    if os.path.isdir(f'{data.out_path}/output'):
        log.info('JSON directory found, not making it')
        return

    # Sandboxes are the entries whose name starts with nine digits. Matching on the
    # basename keeps regex metacharacters in out_path out of the pattern
    l_dirname = [dirname for dirname in glob.glob(f'{data.out_path}/*') if re.match(r'\d{9}', os.path.basename(dirname))]
    if not l_dirname:
        message = f'Found no sandboxes in {data.out_path}'
        log.error(message)
        raise RuntimeError(message)

    l_tar_path = [f'{dirname}/result_jsn.tar.gz' for dirname in l_dirname]
    l_tar_path = [tar_path for tar_path in l_tar_path if os.path.isfile(tar_path)]

    if not l_tar_path:
        message = f'Found no tarballs for {data.job_name}'
        log.error(message)
        raise RuntimeError(message)

    log.info('Unpacking JSON files')
    for tar_path in tqdm.tqdm(l_tar_path, ascii=' -'):
        try:
            untar(tar_path)
        except tarfile.ReadError:
            log.warning(f'Cannot untar: {tar_path}')
            # Drop whatever was extracted before the failure, otherwise the
            # next tarball would move these leftovers to the output directory
            if os.path.isdir('result_jsn'):
                shutil.rmtree('result_jsn')
            continue

        rename_jsn()
        # A tarball without a result_jsn directory leaves nothing to clean up
        if os.path.isdir('result_jsn'):
            shutil.rmtree('result_jsn')
#-------------------------------------------------------
def get_data(json_path, kind):
    '''
    Takes path to result_xxxx.json and returns a flat dictionary built from
    the section `kind` of that file:

    - Two-element lists are read as [value, error] and stored under the keys
      '<name> value' and '<name> error', each as a one-element list of float
    - Entries that are float, bool or int are stored under their own name as float
    - Anything else is dropped

    The parameter keys come first, followed by the scalar ones
    '''
    d_kind = utnr.load_json(json_path)[kind]

    d_out = {}
    for name, entry in d_kind.items():
        if not (isinstance(entry, list) and len(entry) == 2):
            continue

        value, error = entry
        d_out[f'{name} value'] = [float(value)]
        d_out[f'{name} error'] = [float(error)]

    # Second pass, so that scalars end up after all the parameters
    for name, entry in d_kind.items():
        if isinstance(entry, (float, bool, int)):
            d_out[name] = float(entry)

    return d_out
#-------------------------------------------------------
def get_df(kind):
    '''
    Returns dataframe made by stacking the contents of the section `kind`
    of each JSON file, the index is rebuilt from zero
    '''
    l_frame = []
    for json_path in get_json_paths():
        d_entry = get_data(json_path, kind)
        l_frame.append(pnd.DataFrame(d_entry))

    df_all = pnd.concat(l_frame, axis=0)

    return df_all.reset_index(drop=True)
#-------------------------------------------------------
def get_moments(ser):
    '''
    Returns mean and standard deviation of the values in the series,
    calculated after passing them through utnr.remove_outliers
    '''
    l_kept = utnr.remove_outliers(ser.tolist())

    return numpy.mean(l_kept), numpy.std(l_kept)
#-------------------------------------------------------
def get_json_paths():
    '''
    Returns sorted list of paths to the JSON files with the fit results

    If a job name was set, files are searched in <out_path>/output, otherwise
    in <out_path>/*/results/result_jsn

    Raises
    ----------
    RuntimeError: If no file matches
    '''
    if data.job_name is not None:
        json_wc = f'{data.out_path}/output/*.json'
    else:
        json_wc = f'{data.out_path}/*/results/result_jsn/*.json'

    # glob returns paths in arbitrary order. Sorting makes every call return
    # the files in the same sequence, which the row by row pairing of the
    # dataframes built from separate calls relies on
    l_json_path = sorted(glob.glob(json_wc))
    if not l_json_path:
        message = f'No JSON file found in: {json_wc}'
        log.error(message)
        raise RuntimeError(message)

    return l_json_path
#-------------------------------------------------------
def plot(df_pos=None, df_pre=None, var=None):
    '''
    Makes pull, value, error and fit quality plots for one variable

    Parameters
    ----------
    df_pos: Dataframe with columns '<var> value' and '<var> error', used as fitted value and its error
    df_pre: Dataframe with the same columns, used as reference value and constraint width
    var (str): Name of the variable
    '''
    os.makedirs(f'{data.out_path}/plots', exist_ok = True)

    col_val = f'{var} value'
    col_err = f'{var} error'

    sr_val, sr_err = df_pos[col_val], df_pos[col_err]
    sr_pre, sr_cns = df_pre[col_val], df_pre[col_err]

    # Reference value and constraint width are taken from the entry labeled 0
    gen = sr_pre[0]
    cns = sr_cns[0]

    plot_pull((sr_val - sr_pre) / sr_err, var)
    plot_vals(sr_val, gen, cns, var)
    plot_errs(sr_err, gen, cns, var)
    plot_qlty(df_pos)
#-------------------------------------------------------
def plot_pull(sr_pul, var):
    '''
    Saves histogram of pulls to <out_path>/plots/pulls/<var>.png

    Vertical lines mark the mean (dashed) and the mean -/+ one standard
    deviation (dotted), both taken from get_moments

    Parameters
    ----------
    sr_pul: Pandas series with the pulls
    var (str): Name of the variable, used for the title and the file name
    '''
    mu, sd = get_moments(sr_pul)

    sr_pul.hist(bins=50, range=(-3, +3), histtype='step')
    plt.axvline(x=mu   , color='red', linestyle='--')
    plt.axvline(x=mu-sd, color='red', linestyle=':')
    plt.axvline(x=mu+sd, color='red', linestyle=':')

    os.makedirs(f'{data.out_path}/plots/pulls', exist_ok=True)

    plot_path=f'{data.out_path}/plots/pulls/{var}.png'
    log.info(f'Saving to: {plot_path}')
    plt.title(f'{var} pull')
    # Raw strings: \m and \s are not valid escape sequences in a normal string literal
    plt.legend(['Pull', r'$\mu$', r'$\mu-\sigma$', r'$\mu+\sigma$'])
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------------------------
def get_range(var, error=False):
    '''
    Returns the histogram range for a variable: (0, 2) for the value of rk,
    None for anything else
    '''
    is_rk_value = (var == 'rk') and not error

    return (0, 2) if is_rk_value else None
#-------------------------------------------------------
def plot_vals(sr_val, gen, cns, var):
    '''
    Saves histogram of fitted values to <out_path>/plots/values/<var>.png

    Parameters
    ----------
    sr_val: Pandas series with the fitted values
    gen (float): Generated value, marked with a dashed line
    cns (float): Constraint width, unless it is below 1e-6 dotted lines are drawn at gen +/- cns
    var (str): Name of the variable, used for the title and the file name
    '''
    plot_dir = f'{data.out_path}/plots/values'
    os.makedirs(plot_dir, exist_ok=True)
    plot_path = f'{plot_dir}/{var}.png'

    sr_val.plot.hist(bins=50, range=get_range(var, error=False), histtype='step')
    plt.axvline(x=gen, color='red', linestyle='--')

    l_label = ['Fitted', f'Generated={gen:.3f}']
    is_constrained = not (cns < 1e-6)
    if is_constrained:
        for edge in (gen + cns, gen - cns):
            plt.axvline(x=edge, color='red', linestyle=':')
        l_label += ['+const', '-const']

    plt.legend(l_label)

    plt.title(var)
    log.info(f'Saving to: {plot_path}')
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------------------------
def plot_errs(sr_val, gen, cns, var):
    '''
    Saves histogram of fitted errors to <out_path>/plots/errors/<var>.png

    Up to three dashed reference lines are added:

    - sqrt(gen), if the variable name starts with n
    - The constraint width, if it is above 1e-6
    - The mean of the errors, for rk

    Parameters
    ----------
    sr_val: Pandas series with the errors of the fitted parameter
    gen (float): Generated value
    cns (float): Constraint width
    var (str): Name of the variable, used for the title and the file name
    '''
    os.makedirs(f'{data.out_path}/plots/errors', exist_ok=True)
    plot_path = f'{data.out_path}/plots/errors/{var}.png'

    sr_val.plot.hist(bins=50, range=get_range(var, error=True), histtype='step')
    var = var.replace('_', ' ')
    plt.title(f'$\\varepsilon{{{var}}}$')
    log.info(f'Saving to: {plot_path}')

    # One label per artist, in drawing order. The legend is made once at the
    # end, calling plt.legend per line would replace the previous legend and
    # attach the labels to the wrong lines when several of them are drawn
    l_label = ['Error']

    if var.startswith('n'):
        err = math.sqrt(gen)
        plt.axvline(x=err, color='red', linestyle='--')
        l_label.append(r'$\sqrt{Generated}$')

    if cns > 1e-6:
        plt.axvline(x=cns, color='red', linestyle='--')
        l_label.append('Constraint width')

    if var == 'rk':
        err_exp = sr_val.mean()
        plt.axvline(x=err_exp, color='red', linestyle='--')
        l_label.append(rf'$\mu={err_exp:.3f}$')

    # No legend when no reference line was drawn
    if len(l_label) > 1:
        plt.legend(l_label)

    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------------------------
def plot_cov(cov, l_var):
    '''
    Plots the covariance matrix and the correlation matrix derived from it
    to covariance.png and correlation.png in <out_path>/plots

    Parameters
    ----------
    cov: Covariance matrix
    l_var: Labels used for both axes
    '''
    plot_dir = f'{data.out_path}/plots'
    os.makedirs(plot_dir, exist_ok=True)

    utnr.plot_matrix(
            f'{plot_dir}/covariance.png',
            l_var,
            l_var,
            cov,
            upper=True,
            title='',
            form='{:.3f}',
            fsize=[10, 10])

    cor = utnr.correlation_from_covariance(cov)

    utnr.plot_matrix(
            f'{plot_dir}/correlation.png',
            l_var,
            l_var,
            cor,
            upper=True,
            title='',
            form='{:.3f}',
            fsize=[10, 10])
#-------------------------------------------------------
def freq_one(df, quantity):
    '''
    Counts how many entries of a column are equal to one

    Parameters
    ----------
    df: Dataframe holding the column
    quantity (str): Name of the column

    Returns
    ----------
    Tuple of integers with the number of entries equal to one and the number of remaining entries
    '''
    sr_is_one = df[quantity] == 1

    # The mask is as long as the column, the entries passing have to be summed
    none = int(sr_is_one.sum())
    ntot = len(sr_is_one)

    return none, ntot - none
#-------------------------------------------------------
def plot_qlty(df):
    '''
    Saves to <out_path>/plots/quality/summary.png the number of good and
    bad fits, according to each of the converged, status and valid columns

    Parameters
    ----------
    df: Dataframe with the columns converged, status and valid
    '''
    os.makedirs(f'{data.out_path}/plots/quality', exist_ok=True)
    plot_path = f'{data.out_path}/plots/quality/summary.png'

    cnv_y, cnv_n = freq_one(df, 'converged')
    val_y, val_n = freq_one(df, 'valid')

    # Unlike the other two flags, a good fit has status zero, as required in filter_df
    sta_y = int((df['status'] == 0).sum())
    sta_n = len(df) - sta_y

    xerr = [0.5, 0.5, 0.5]
    xval = [1.0, 2.0, 3.0]
    plt.errorbar(xval, [cnv_y, sta_y, val_y], xerr=xerr, label='Good' , marker='o', linestyle='None')
    plt.errorbar(xval, [cnv_n, sta_n, val_n], xerr=xerr, label='Bad'  , marker='o', linestyle='None')

    plt.title('Fit quality')
    log.info(f'Saving to: {plot_path}')
    plt.grid()
    plt.legend()
    plt.xticks(xval, ['Converged', 'Status', 'Valid'])
    plt.savefig(plot_path)
    plt.close('all')
#-------------------------------------------------------
def filter_df(df):
    '''
    Returns the dataframe it is given, untouched

    NOTE(review): The unconditional return below makes the good fit selection
    unreachable, so data.good_fit currently has no effect. Presumably it was
    switched off on purpose; before removing the return, confirm that every
    dataframe passed here carries the valid, converged and status columns and
    that dropping rows keeps the dataframes compared by the caller aligned.
    '''
    return df

    if not data.good_fit:
        log.info('Not filtering dataframe')
        return df

    log.warning('Filtering dataframe')

    is_good = (df.valid == 1) & (df.converged == 1) & (df.status == 0)

    return df[is_good].reset_index(drop=True)
#-------------------------------------------------------
def get_args():
    '''
    Parses the command line arguments and stores them in the data class

    The output path is output_<job_name> if a job name is given, otherwise
    the job path is used

    Raises
    ----------
    RuntimeError: If neither a job name nor a job path is passed
    '''
    parser = argparse.ArgumentParser(description='Will make plots from the results of toy fits')
    parser.add_argument('-n','--job_name', type=str, help='Name of job, for grid jobs')
    parser.add_argument('-p','--job_path', type=str, help='Path to job output, for IHEP tests')
    parser.add_argument('-g','--job_fit' if False else '--good_fit', help='Will plot only good fits', action='store_true')
    args = parser.parse_args()

    # An empty string is treated as a missing argument, so that the checks on
    # data.job_name done with `is None` and with truthiness always agree
    job_name = args.job_name or None

    if job_name is None and not args.job_path:
        message = 'Neither job name or job path passed'
        log.error(message)
        raise RuntimeError(message)

    data.job_name = job_name
    data.out_path = f'output_{job_name}' if job_name else args.job_path
    data.good_fit = args.good_fit
#-------------------------------------------------------
def get_covariance(json_path):
    '''
    Returns the entry stored under ['pos']['cov'] in the JSON file as a numpy array of floats
    '''
    d_fit = utnr.load_json(json_path)
    arr   = numpy.array(d_fit['pos']['cov'])

    return arr.astype('float')
#-------------------------------------------------------
def build_covariance():
    '''
    Returns the element-wise mean of the covariance matrices of all the JSON files

    The result is cached in <out_path>/covariance.json, if that file already
    exists its contents are returned instead of redoing the calculation
    '''
    cache_path = f'{data.out_path}/covariance.json'

    if not os.path.isfile(cache_path):
        l_cov    = [get_covariance(json_path) for json_path in get_json_paths()]
        mean_cov = numpy.mean(l_cov, axis=0)
        utnr.dump_json(mean_cov.tolist(), cache_path)

        return mean_cov

    return numpy.array(utnr.load_json(cache_path))
#-------------------------------------------------------
def build_df(kind):
    '''
    Returns the dataframe for the section `kind` of the JSON files

    The dataframe is cached in <out_path>/<kind>.json. If the cache is missing
    it gets made first; the dataframe is always read back from the cache and
    passed through filter_df
    '''
    cache_path = f'{data.out_path}/{kind}.json'

    is_cached = os.path.isfile(cache_path)
    if not is_cached:
        make_json()
        get_df(kind).to_json(cache_path, indent=4)

    return filter_df(pnd.read_json(cache_path))
#-------------------------------------------------------
def main():
    '''
    Makes the plots for every variable present in both the pos and pre
    dataframes and then plots the covariance and correlation matrices
    '''
    df_pos = build_df('pos')
    df_pre = build_df('pre')

    l_pos  = df_pos.columns.tolist()
    l_pre  = df_pre.columns.tolist()

    # Columns shared by both dataframes, with the value/error suffix dropped
    s_common = set(l_pos).intersection(l_pre)
    s_name   = { column.replace(' value', '').replace(' error', '') for column in s_common }

    for name in s_name:
        plot(df_pos=df_pos, df_pre=df_pre, var=name)

    # Labels of the matrices: value columns of pos, in column order, leaving out nsg_ee
    l_label = []
    for column in l_pos:
        if 'value' not in column or 'nsg_ee' in column:
            continue

        l_label.append(column.replace(' value', ''))

    plot_cov(build_covariance(), l_label)
#-------------------------------------------------------
# Entry point when run as a script: get_args() fills the data class from the
# command line, the functions called from main() read their settings from it
if __name__ == '__main__':
    get_args()
    main()

