#!/usr/bin/env python3

from rk.swt_copy  import copy      as swcp
from rk.ds_getter import ds_getter as dsg
from logzero      import logger    as log

import ROOT
import os
import argparse
import pandas       as pnd
import utils_noroot as utnr
#----------------------------------------
class data:
    '''
    Namespace holding the allowed command line choices and the parsed settings.

    The settings start as None and are filled in by the argument parser.
    '''
    # Allowed values for the process (sample) argument
    l_proc = [
        'data',
        'sign',
        'ctrl',
        'ctrl_pi',
        'bd_x',
        'bp_x',
        'bs_x',
        'psi2',
        'bdks',
        'bpks',
        'bdkpi',
        'cmb',
        'bpkkk',
        'bpkpipi',
    ]
    # Allowed values for the dataset argument
    l_dset = [
        '2011',
        '2012',
        '2015',
        '2016',
        '2017',
        '2018',
    ]
    # Allowed values for the q2 bin argument
    l_q2bin= [
        'jpsi',
        'psi2',
        'high',
    ]

    # Settings taken from the command line
    q2bin  = None # q2 bin
    proc   = None # process
    vers   = None # version
    dset   = None # dataset
    trig   = None # trigger
    ipart  = None # first value of the partition argument
    npart  = None # second value of the partition argument
    skip   = None # cuts to skip
    prec   = None # flag, skip prec bdt
    pref   = None # name of directory where samples go
    swgt   = None # flag, add sweights from old files
#----------------------------------------
class cache_data:
    '''
    Build the selected dataframe for one partition of a sample and cache it on disk.

    The output ROOT file, together with the JSON and pickle files derived from the
    dataframe's `cf` attribute, is written below `$CASDIR/tools/apply_selection`.
    If the ROOT file already exists, `save` does nothing.
    '''
    def __init__(self, proc=None, vers=None, dset=None, trig=None, q2bin = None, ipart=None, npart=None):
        '''
        Parameters
        ----------
        proc  : Process (sample) name, used in the output path and passed to ds_getter
        vers  : Version, used in the output path and passed to ds_getter
        dset  : Dataset, used in the output path and passed to ds_getter
        trig  : Trigger, one of ETOS, GTIS or MTOS
        q2bin : q2 bin, passed to ds_getter
        ipart : First element of the partition tuple passed to ds_getter
        npart : Second element of the partition tuple passed to ds_getter
        '''
        self._proc        = proc
        self._vers        = vers
        self._dset        = dset
        self._trig        = trig
        self._q2bin       = q2bin
        self._ipart       = ipart
        self._npart       = npart

        self._selection   = 'all_gorder'
        self._truth_corr  = 'final_no_truth_mass_bdt'

        # Filled by _initialize
        self._chan        = None
        self._tree        = None
        self._cache_dir   = None
        self._swt_dir     = None

        self._initialized = False
    #----------------------------------------
    def _initialize(self):
        '''
        Validate the inputs and resolve the directories taken from the environment.

        Runs only once per instance.

        Raises
        ------
        KeyError     : If $DATDIR is not in the environment
        ValueError   : If any of proc, vers, dset, trig, ipart, npart is None
        RuntimeError : If the trigger is invalid or $CASDIR is not in the environment
        '''
        if self._initialized:
            return

        self._setup_vars()

        dat_dir       = os.environ['DATDIR']
        self._swt_dir = f'{dat_dir}/data_{self._chan}/cali_new/sweights_v5'

        for name in ['proc', 'vers', 'dset', 'trig', 'ipart', 'npart']:
            self._check_none(getattr(self, f'_{name}'), name)

        if 'CASDIR' not in os.environ:
            msg = 'Caching directory, $CASDIR, not found in environment'
            log.error(msg)
            raise RuntimeError(msg)

        self._cache_dir   = os.environ['CASDIR']
        self._initialized = True
    #----------------------------------------
    def _setup_vars(self):
        '''
        Set the channel and tree name from the trigger.

        Raises
        ------
        RuntimeError : If the trigger is not one of ETOS, GTIS, MTOS
        '''
        if   self._trig in ['ETOS', 'GTIS']:
            self._chan = 'ee'
            self._tree = 'KEE'
        elif self._trig in ['MTOS']:
            self._chan = 'mm'
            self._tree = 'KMM'
        else:
            msg = f'Invalid trigger: {self._trig}'
            log.error(msg)
            raise RuntimeError(msg)
    #----------------------------------------
    def _check_none(self, var, name):
        '''
        Raise ValueError if `var` is None, `name` is only used in the message.
        '''
        if var is None:
            msg = f'Variable {name} not initialized'
            log.error(msg)
            raise ValueError(msg)
    #----------------------------------------
    def _cache_path(self, preffix):
        '''
        Return the path of the output ROOT file and a flag telling if it already exists.

        The directory that holds the file is created if it is missing.
        '''
        path_dir = f'{self._cache_dir}/tools/apply_selection/{preffix}/{self._proc}/{self._vers}/{self._dset}_{self._trig}'
        os.makedirs(path_dir, exist_ok=True)
        path     = f'{path_dir}/{self._ipart}_{self._npart}.root'

        is_cached = os.path.isfile(path)
        if is_cached:
            log.info(f'Loading cached data: {path}')

        return path, is_cached
    #----------------------------------------
    def _add_swt(self, rdf_tgt):
        '''
        Attach the weight `sw_<trigger>` to `rdf_tgt`, taking it from the sweighted file
        of this dataset in the sweights directory. Returns what `attach_swt` returns.
        '''
        swt_path=f'{self._swt_dir}/{self._dset}_dt_trigger_weights_sweighted.root'
        log.info(f'Extracting sweights from: {swt_path}:{self._tree}')
        rdf_src =ROOT.RDataFrame(self._tree, swt_path)

        obj=swcp(src=rdf_src, tgt=rdf_tgt)
        rdf=obj.attach_swt(l_wgt_name=[f'sw_{self._trig}'])

        return rdf
    #----------------------------------------
    def save(self, l_skip_cut=None, skip_prec=False, preffix='unnamed', add_swt=False):
        '''
        Produce the dataframe and write it, unless the output file is already cached.

        Parameters
        ----------
        l_skip_cut : Names of cuts to redefine as '(1)', None is treated as no cuts
        skip_prec  : Forwarded to ds_getter's get_df
        preffix    : Name of the directory, below apply_selection, where the outputs go
        add_swt    : If True, sweights are attached before saving

        Outputs, next to `<ipart>_<npart>.root`: `_eff.json`, `_cut.json` and `.pkl`,
        all taken from the `cf` attribute of the dataframe.
        '''
        self._initialize()

        ntp_path, is_cached = self._cache_path(preffix)
        if is_cached:
            return

        # The default is None, iterating over it would raise TypeError
        if l_skip_cut is None:
            l_skip_cut = []

        part       = (self._ipart, self._npart)
        d_redefine = { f'{cut}' : '(1)' for cut in l_skip_cut }

        obj = dsg(self._q2bin, self._trig, self._dset, self._vers, part, self._proc, self._selection)
        rdf = obj.get_df(d_redefine=d_redefine, skip_prec=skip_prec)

        cut_path = ntp_path.replace('.root', '_cut.json')
        eff_path = ntp_path.replace('.root', '_eff.json')
        cfl_path = ntp_path.replace('.root',      '.pkl')

        if add_swt:
            rdf = self._add_swt(rdf)

        log.info(f'Saving to: {eff_path}')
        log.info(f'Saving to: {cut_path}')
        log.info(f'Saving to: {cfl_path}')
        log.info(f'Saving to: {ntp_path}:{self._trig}')

        rdf.cf.df_eff.to_json(eff_path, indent=4)
        rdf.cf.df_cut.to_json(cut_path, indent=4)
        utnr.dump_pickle(rdf.cf, cfl_path)
        rdf.Snapshot(self._trig, ntp_path)
#----------------------------------------
def get_args(argv=None):
    '''
    Parse the command line and store the result in the attributes of `data`.

    Parameters
    ----------
    argv : List of strings to parse, if None argparse reads the process arguments

    The partition is given as exactly two integers, stored in `data.ipart` and
    `data.npart`. Passing `none` as the only cut to skip is the same as skipping
    no cuts.
    '''
    parser = argparse.ArgumentParser(description='Used to get datasets from OS and SS data, after full selection, but BDT, used for combinatorial PDF studies')
    parser.add_argument('-s', '--samp' , type= str, help='Process'  , required=True, choices=data.l_proc)
    parser.add_argument('-v', '--vers' , type= str, help='Version'  , required=True)
    parser.add_argument('-d', '--dset' , type= str, help='Dataset'  , required=True, choices=data.l_dset)
    parser.add_argument('-q', '--q2bin', type= str, help='q2 bin'   , required=True, choices=data.l_q2bin)
    parser.add_argument('-t', '--trig' , type= str, help='Trigger'  , required=True, choices=['ETOS', 'GTIS', 'MTOS'])
    # Exactly two integers are needed, let argparse report a wrong number or type
    parser.add_argument('-p', '--part' , type= int, help='partition', required=True, nargs=2, metavar=('IPART', 'NPART'))
    parser.add_argument('-r', '--skip' , nargs='+', help='Skip cuts', default=[])
    parser.add_argument('-b', '--prec' ,            help='Skip prec bdt', action='store_true')
    parser.add_argument('-f', '--pref' , type= str, help='Name of directory where samples go')
    parser.add_argument('-w', '--swgt' ,            help='Add sweights from old files', action='store_true')

    args = parser.parse_args(argv)

    data.proc = args.samp
    data.vers = args.vers
    data.dset = args.dset
    data.trig = args.trig
    data.skip = args.skip
    data.prec = args.prec
    data.q2bin= args.q2bin
    data.pref = args.pref
    data.swgt = args.swgt

    if data.skip == ['none']:
        data.skip = []

    data.ipart, data.npart = args.part
#----------------------------------------
if __name__ == '__main__':
    # Fill the attributes of `data` from the command line
    get_args()

    # Settings identifying the sample and partition to cache
    obj = cache_data(
            proc =data.proc,
            vers =data.vers,
            dset =data.dset,
            trig =data.trig,
            q2bin=data.q2bin,
            ipart=data.ipart,
            npart=data.npart)

    # Produce and store the output, nothing is done if it is already cached
    obj.save(
            l_skip_cut=data.skip,
            skip_prec =data.prec,
            preffix   =data.pref,
            add_swt   =data.swgt)

