#!python

import os
import ROOT
import glob
import pprint
import argparse
import utils_noroot      as utnr
import pandas            as pnd
import matplotlib.pyplot as plt

from log_store import log_store

log = log_store.add_logger('rx_selection:check_sl_rejection')
#----------------------------
class data:
    '''
    Namespace holding the configuration shared by the functions of this script.
    It is never instantiated, the attributes are read directly from the class.
    '''
    # Years whose files are read
    l_year = ['2012', '2018']

    # Sample identifier (used as directory name and in plot names) -> LaTeX label of the decay
    d_samp = {
            'bpd0kppienu'  : r'$B^+\to D^0(\to K^+ \pi^-  ) e^+\nu$',
            'bpd0kpenuenu' : r'$B^+\to D^0(\to K^+ e^- \nu) e^+\nu$',
            'bpd0kpenupi'  : r'$B^+\to D^0(\to K^+ e^- \nu)  \pi^+$',
            }

    # Veto expressions, applied one after the other as dataframe filters
    l_cut_v= [
            'kl > 1885',
            '(kl_M_ltrack2pi < 1825 || kl_M_ltrack2pi > 1905)'
            ]

    # LaTeX names of the vetoes above, in the same order.
    # These are raw strings, a single backslash is needed before each LaTeX command
    l_cut_n= [
            r'$m(K^{\pm} e^{\mp}) > 1885$',
            r'$m(K^{\pm} e^{\mp}_{e\to\pi})\notin [1825, 1905]$',
            ]

    # Version tag of the input directory read in get_rdf
    vers   = 'v10.21p2'
    # Trigger category, used both in the input directory name and as the name of the tree
    trig   = 'ETOS'
    # Root of the input files, taken from the environment when the class is defined,
    # a missing CASDIR variable raises KeyError at import time
    cas_dir= os.environ['CASDIR']
    # Directory where plots and tables are saved
    out_dir= 'sl_plots'
    # Number of bins of the mass histograms
    nbins  = 50
#----------------------------
def get_rdf(samp):
    '''
    Build a dataframe with the files of a given sample, for every year in data.l_year

    Parameters
    ----------
    samp (str): Sample identifier, used as the name of the directory holding the files

    Returns
    ----------
    ROOT.RDataFrame built from the tree named data.trig in all the files found

    Raises
    ----------
    FileNotFoundError: If no ROOT file is found for the sample
    '''
    # Version taken from the configuration, instead of repeating it here
    data_dir = f'{data.cas_dir}/tools/apply_selection/sl_bkg_rej/{samp}/{data.vers}'

    l_data   = []
    for year in data.l_year:
        data_wc = f'{data_dir}/{year}_{data.trig}/*.root'
        # Sorted, because glob does not guarantee any ordering
        l_data += sorted(glob.glob(data_wc))

    # Fail here, with the path that was searched, rather than later inside ROOT
    if not l_data:
        raise FileNotFoundError(f'No ROOT files found for sample {samp} in: {data_dir}')

    rdf = ROOT.RDataFrame(data.trig, l_data)

    return rdf
#----------------------------
def add_lines(arr_val, samp, kind):
    '''
    Draw on the current figure the dotted red lines marking the veto boundaries
    and the red arrows attached to them.

    Parameters
    ----------
    arr_val: Plotted values, currently not used
    samp (str): Sample identifier, the arrows of bpd0kppienu are placed differently from the rest
    kind (str): If "kl" a single line is drawn at 1885, otherwise two lines are drawn at 1825 and 1905
    '''
    is_kl = kind == 'kl'
    is_pi = samp == 'bpd0kppienu'

    # Only the first line carries the label, such that the legend has a single "Veto" entry
    l_xval = [1885] if is_kl else [1825, 1905]
    for index, xval in enumerate(l_xval):
        d_opt = {'label' : 'Veto'} if index == 0 else {}
        plt.axvline(x=xval, color='r', linestyle=':', **d_opt)

    # (is_pi, is_kl) -> list of (xy, xytext, arrowstyle), in drawing order
    d_arrow = {
            (False, True ) : [((1885,  60), (2085,  60), '<-')],
            (False, False) : [((1905,  50), (2105,  50), '<-'), ((1625,  50), (1825,  50), '->')],
            (True , True ) : [((1885, 400), (1925, 400), '<-')],
            (True , False) : [((1905, 200), (1935, 200), '<-'), ((1825, 200), (1795, 200), '<-')],
            }

    for pos, pos_text, style in d_arrow[(is_pi, is_kl)]:
        plt.annotate('', xy=pos, xytext=pos_text, arrowprops=dict(arrowstyle=style, color='r'))
#----------------------------
def get_range(samp):
    '''
    Pick the plotting range of the mass histograms for a given sample

    Parameters
    ----------
    samp (str): Sample identifier

    Returns
    ----------
    Tuple with:

    - Tuple with the lower and upper edges, (1000, 2500) for bpd0kpenuenu and bpd0kpenupi, (1800, 2000) otherwise
    - String with the bin width for data.nbins bins, formatted with three decimals
    '''
    is_wide   = samp in ('bpd0kpenuenu', 'bpd0kpenupi')
    rng       = (1000, 2500) if is_wide else (1800, 2000)

    low, high = rng
    bin_width = (high - low) / data.nbins

    return rng, f'{bin_width:.3f}'
#----------------------------
def plot_v1(rdf, proc, samp):
    '''
    Plot the distribution of the kl column, together with the veto line, and save it
    to data.out_dir as {samp}_kl_mass.png

    Parameters
    ----------
    rdf : Dataframe with the kl column
    proc (str): Label of the histogram in the legend
    samp (str): Sample identifier, it picks the plotting range and names the output file
    '''
    rng, width = get_range(samp)

    arr_val= rdf.AsNumpy(['kl'])['kl']
    plt.hist(arr_val, histtype='step', color='b', range=rng, bins=data.nbins, label=proc)
    add_lines(arr_val, samp, 'kl')

    plt.legend(loc='best')
    # Raw string, \p and \m are not valid escape sequences in a normal string
    plt.xlabel(r'$m(K^{\pm} e^{\mp})$ [MeV]')
    plt.ylabel(f'Entries/{width}MeV')

    plot_path = f'{data.out_dir}/{samp}_kl_mass.png'
    log.info(f'Saving to: {plot_path}')
    plt.savefig(plot_path)
    # Closed such that the next plot starts from an empty figure
    plt.close('all')
#----------------------------
def plot_v2(rdf, proc, samp):
    '''
    Plot the distribution of the kl_M_ltrack2pi column, together with the veto lines, and save it
    to data.out_dir as {samp}_kltopi_mass.png

    Parameters
    ----------
    rdf : Dataframe with the kl_M_ltrack2pi column
    proc (str): Label of the histogram in the legend
    samp (str): Sample identifier, it picks the plotting range and names the output file
    '''
    rng, width = get_range(samp)

    arr_val= rdf.AsNumpy(['kl_M_ltrack2pi'])['kl_M_ltrack2pi']
    plt.hist(arr_val, histtype='step', color='b', range=rng, bins=data.nbins, label=proc)
    # Any kind different from "kl" draws the two lines of the mass window
    add_lines(arr_val, samp, 'kp')

    plt.legend(loc='best')
    # Raw string, \p and \m are not valid escape sequences in a normal string
    plt.xlabel(r'$m(K^{\pm} e^{\mp}_{e\to\pi})$ [MeV]')
    plt.ylabel(f'Entries/{width}MeV')

    plot_path = f'{data.out_dir}/{samp}_kltopi_mass.png'
    log.info(f'Saving to: {plot_path}')
    plt.savefig(plot_path)
    # Closed such that the next plot starts from an empty figure
    plt.close('all')
#----------------------------
def get_args():
    '''
    Parse the command line arguments. No options are defined, the parser only provides
    the -h/--help flag and rejects any unexpected argument.

    Returns
    ----------
    argparse.Namespace with the parsed arguments
    '''
    parser = argparse.ArgumentParser(description='Used to check rejection due to cascade vetos in SL samples')
    args   = parser.parse_args()

    return args
#----------------------------
def get_efficiencies(rdf):
    '''
    Calculate the cumulative efficiency of the cuts in data.l_cut_v, applied in order,
    and print the cutflow report.

    Parameters
    ----------
    rdf : Dataframe with the columns used by the cuts

    Returns
    ----------
    List of floats, one per cut, with the fraction of the initial entries
    passing that cut and all the previous ones

    Raises
    ----------
    ValueError: If the dataframe has no entries
    '''
    # Everything is booked before any value is retrieved,
    # such that a single loop over the data fills all the results
    res_tot = rdf.Count()

    l_res   = []
    for cut in data.l_cut_v:
        # The report only contains named filters, the expression is used as the name
        rdf = rdf.Filter(cut, cut)
        l_res.append(rdf.Count())

    rep  = rdf.Report()

    ntot = res_tot.GetValue()
    if ntot == 0:
        raise ValueError('Cannot calculate efficiencies, the dataframe has no entries')

    l_eff = [ res.GetValue() / ntot for res in l_res ]

    rep.Print()

    return l_eff
#----------------------------
def main():
    '''
    Entry point. For each sample in data.d_samp it calculates the veto efficiencies and makes
    the mass plots. The efficiencies are then saved as a LaTeX table in data.out_dir
    '''
    get_args()
    os.makedirs(data.out_dir, exist_ok=True)

    d_eff = {}
    for samp, name in data.d_samp.items():
        rdf         = get_rdf(samp)
        d_eff[name] = get_efficiencies(rdf)

        # The plots use the dataframe before the vetoes are applied
        plot_v1(rdf, name, samp)
        plot_v2(rdf, name, samp)

    # Columns are samples and rows are cuts, labelled by position.
    # Names and formats are built from the full list of cuts, not from hard-coded indices
    df     = pnd.DataFrame(d_eff)
    df     = df.rename(index = dict(enumerate(data.l_cut_n)))
    # Transposed to have one row per sample and one column per cut
    df     = df.T
    d_form = { cut_name : '{:.3f}' for cut_name in data.l_cut_n }
    utnr.df_to_tex(df, f'{data.out_dir}/efficiencies.tex', hide_index=False, d_format=d_form)
#----------------------------
# Run the checks only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()

