#!/usr/bin/env python3

import os
import re
import ROOT
import zfit
import glob
import math
import numpy
import pprint
import argparse
import pandas              as pnd
import jacobi              as jac
import logzero
import tarfile
import rk.utilities        as rkut
import utils_noroot        as utnr

from rkex_model import model
from np_reader  import np_reader as np_rdr
from mc_reader  import mc_reader as mc_rdr
from cs_reader  import cs_reader as cs_rdr
from extractor  import extractor as ext
from logzero    import logger    as log
from zutils     import utils     as zut

#--------------------------------
class data:
    '''
    Namespace holding the settings shared by the functions of this script.
    '''
    # Directory under which the per-seed results and their tarballs are written
    out_dir = 'results'

    # Filled at run time: l_seed by initialize(), l_dset and log_lvl by get_args()
    l_seed = l_dset = log_lvl = None
#--------------------------------
def get_ne(suffix, d_pos):
    '''
    Propagate fitted parameters into an electron signal yield and its error.

    The yield is computed as ck * nsg_mm / rk and its uncertainty is obtained
    by propagating the corresponding 3x3 block of the fit covariance matrix
    through that expression with jacobi.

    Parameters
    ----------
    suffix : str
        Suffix identifying the category; `ck_{suffix}` is the ck parameter used.
        The muon yield is always read from the TOS version of the suffix.
    d_pos : dict
        Maps parameter names to (value, error) pairs and additionally holds
        'par' (list of parameter names, in covariance order) and 'cov'
        (covariance matrix, anything convertible to a float array).

    Returns
    -------
    list
        [value, error] of the electron signal yield, as python floats.

    Raises
    ------
    KeyError
        If a needed parameter is missing from d_pos.
    ValueError
        If a needed parameter is not listed in d_pos['par'].
    '''
    ck_name     = f'ck_{suffix}'
    # A TIS suffix still uses the muon yield of the matching TOS category
    suffix_tos  = suffix.replace('_TIS_', '_TOS_')
    nsg_mm_name = f'nsg_mm_{suffix_tos}'
    l_name      = [ck_name, nsg_mm_name, 'rk']

    # Each entry is a (value, error) pair, only the value is needed here
    l_val       = [ float(d_pos[name][0]) for name in l_name ]

    l_par       = d_pos['par']
    l_ind       = [ l_par.index(name) for name in l_name ]

    cov         = numpy.array(d_pos['cov'])
    cov         = cov.astype(float)
    # numpy.ix_ selects the rows and columns in l_ind explicitly; indexing with a
    # nested list is interpreted differently across NumPy releases
    cov         = cov[numpy.ix_(l_ind, l_ind)]

    nsg_ee_val, nsg_ee_var = jac.propagate(lambda x : (x[0] * x[1]) / x[2], l_val, cov )
    nsg_ee_err  = math.sqrt(nsg_ee_var)

    return [float(nsg_ee_val), float(nsg_ee_err)]
#--------------------------------
def add_ne(d_pos):
    '''
    Add electron signal yields to the dictionary of post-fit parameters.

    For every key of the form `nsg_mm_<suffix>` whose suffix contains `_TOS_`,
    an `nsg_ee_<suffix>` entry is computed with get_ne. The matching TIS entry
    is added as well when `ck_<TIS suffix>` is present, otherwise a warning
    is logged and the TIS yield is skipped.

    Parameters
    ----------
    d_pos : dict
        Post-fit parameters, updated in place.

    Returns
    -------
    dict
        The same d_pos object, with the electron yields added.
    '''
    regex     = r'nsg_mm_(.*_TOS_.*)'
    d_pos_ext = {}
    for var_name in d_pos:
        mtch = re.match(regex, var_name)
        if not mtch:
            continue

        suffix_tos = mtch.group(1)
        suffix_tis = suffix_tos.replace('_TOS_', '_TIS_')

        d_pos_ext[f'nsg_ee_{suffix_tos}'] = get_ne(suffix_tos, d_pos)

        ck_tis = f'ck_{suffix_tis}'
        if ck_tis in d_pos:
            d_pos_ext[f'nsg_ee_{suffix_tis}'] = get_ne(suffix_tis, d_pos)
        else:
            log.warning(f'{ck_tis} not found, skipping electron TIS yield')

    # New entries are collected separately because d_pos cannot grow while it is iterated
    d_pos.update(d_pos_ext)

    return d_pos
#--------------------------------
def fit(rseed):
    '''
    Run a single toy fit for the given seed.

    Inputs are loaded through the mc, np and cs readers, the model and the toy
    dataset are built with rkex_model.model, and the fit is done by the extractor.
    After the fit, hesse() is called on the result, the result is converted to a
    dictionary, electron yields are added to it and the result is frozen.

    Parameters
    ----------
    rseed : int
        Seed passed to the toy data generation, also used to name the plots directory.

    Returns
    -------
    tuple
        (result, d_info) where result is the frozen fit result and d_info is
        {'pre' : pre-fit parameters, 'pos' : post-fit parameters}.
    '''
    log.info(f'Seed: {rseed:04}')

    tag = 'toys'

    # 'mu' and 'sg' parameters read with real_data=False
    sim_rdr           = mc_rdr(version='v4', real_data=False)
    sim_rdr.cache     = True
    sim_rdr.cache_dir = 'v4_mcrdr'
    mu_sim            = sim_rdr.get_parameter(name='mu')
    sg_sim            = sim_rdr.get_parameter(name='sg')

    # Same parameters read with real_data=True
    dat_rdr           = mc_rdr(version='v4', real_data=True)
    dat_rdr.cache     = True
    dat_rdr.cache_dir = 'v4_dtrdr'
    mu_dat            = dat_rdr.get_parameter(name='mu')
    sg_dat            = dat_rdr.get_parameter(name='sg')

    # Covariances, efficiencies, rjpsi and yields
    nom_rdr           = np_rdr(sys='v65', sta='v63', yld='v24')
    nom_rdr.cache     = True
    nom_rdr.cache_dir = 'v65_v63_v24'
    cov_sys           = nom_rdr.get_cov(kind='sys')
    cov_sta           = nom_rdr.get_cov(kind='sta')
    eff               = nom_rdr.get_eff()
    rjpsi             = nom_rdr.get_rjpsi()
    byields           = nom_rdr.get_byields()
    avg_yields        = rkut.average_byields(byields, l_exclude=['TIS'])
    rare_yields       = rkut.reso_to_rare(avg_yields, kind='jpsi')

    # Constraints: values and variances
    con_rdr           = cs_rdr(version='v4', preffix=tag)
    con_rdr.cache     = True
    con_rdr.cache_dir = 'v4_csrdr'
    con_val, con_var  = con_rdr.get_constraints()

    builder = model(
            preffix = tag,
            d_eff   = eff,
            d_mcmu  = mu_sim,
            d_mcsg  = sg_sim,
            d_nent  = rare_yields,
            d_dtmu  = mu_dat,
            d_dtsg  = sg_dat)
    pdfs    = builder.get_model()
    toys    = builder.get_data(rseed=rseed)
    prefit  = builder.get_prefit_pars(d_var=con_var, ck_cov=cov_sys+cov_sta)

    fitter          = ext(dset=data.l_dset, drop_correlations=False)
    fitter.plt_dir  = f'plots/fits_{rseed:03}'
    fitter.rjpsi    = rjpsi
    fitter.eff      = eff
    fitter.data     = toys
    fitter.model    = pdfs
    fitter.cov      = cov_sys + cov_sta
    fitter.const    = con_val, con_var
    result          = fitter.get_fit_result()

    log.info(f'Calculating errors')
    result.hesse()

    # The dictionary is built before freezing the result
    posfit = add_ne(rkut.result_to_dict(result))
    result.freeze()

    return result, {'pre' : prefit, 'pos' : posfit}
#--------------------------------
def initialize():
    '''
    Apply the log level, load the seeds and make sure the output directory exists.

    Reads data.log_lvl and data.out_dir, and stores the list returned by
    get_seeds() in data.l_seed.
    '''
    level = data.log_lvl
    log.setLevel(level)

    seeds       = get_seeds()
    data.l_seed = seeds

    target = data.out_dir
    os.makedirs(target, exist_ok=True)
#--------------------------------
def cleanup_env():
    '''
    Remove every entry from zfit's registry of existing parameters.

    Called after each toy fit; presumably so the next fit can create
    parameters with the same names again (to be confirmed against zfit).
    '''
    registry = zfit.Parameter._existing_params

    # Take a snapshot of the keys, the mapping is modified inside the loop
    for par_name in list(registry.keys()):
        del registry[par_name]
#--------------------------------
def get_args():
    '''
    Parse the command line options and store them in the data class.

    -l/--level sets data.log_lvl (one of logzero's DEBUG, INFO or WARNING, INFO by default)
    -d/--dset  sets data.l_dset (one or more dataset names, None if the option is not used)
    '''
    l_level = [logzero.DEBUG, logzero.INFO, logzero.WARNING]

    parser = argparse.ArgumentParser(description='Used run toy fits on model used to extract RK')
    parser.add_argument(
            '-l',
            '--level',
            type    = int,
            help    = 'Logging level',
            choices = l_level,
            default = logzero.INFO)
    parser.add_argument(
            '-d',
            '--dset',
            nargs   = '+',
            help    = 'Datasets to use')

    opts = parser.parse_args()

    data.log_lvl = opts.level
    data.l_dset  = opts.dset
#--------------------------------
def main():
    '''
    Run one toy fit per seed, save each result and pack the outputs.

    For every seed in data.l_seed the fit result is printed and pickled to
    `<out_dir>/result_pkl/result_<seed>.pkl`, and the pre/post-fit parameters
    are written to `<out_dir>/result_jsn/result_<seed>.json`. Once all the
    seeds are done, each of the two directories is packed into a .tar.gz
    file placed next to it.
    '''
    get_args()
    initialize()

    l_dir_name = ['result_pkl', 'result_jsn']

    # Create the output directories here rather than relying on the dump helpers to do it
    for dir_name in l_dir_name:
        os.makedirs(f'{data.out_dir}/{dir_name}', exist_ok=True)

    for rseed in data.l_seed:
        res, d_inf = fit(rseed)
        print(res)
        utnr.dump_pickle(res, f'{data.out_dir}/result_pkl/result_{rseed:04}.pkl')
        utnr.dump_json(d_inf, f'{data.out_dir}/result_jsn/result_{rseed:04}.json')

        # Clear zfit's parameter registry before the next fit
        cleanup_env()

    for dir_name in l_dir_name:
        with tarfile.open(f'{data.out_dir}/{dir_name}.tar.gz', 'w:gz') as tar:
            # arcname keeps the paths inside the archive relative to the directory name
            tar.add(f'{data.out_dir}/{dir_name}', arcname=dir_name)
#--------------------------------
def get_file_seeds(seed_file):
    '''
    Read the seeds stored in a text file, one per line.

    Parameters
    ----------
    seed_file : str
        Path to the file with the seeds.

    Returns
    -------
    list
        The seeds as strings, without surrounding whitespace. Blank lines are
        dropped. No conversion to int is done here.
    '''
    with open(seed_file) as ifile:
        l_line = ifile.read().splitlines()

    # Blank lines are skipped, they cannot be converted to an integer by the caller
    l_seed = [ line.strip() for line in l_line ]

    return [ seed for seed in l_seed if seed ]
#--------------------------------
def get_seeds():
    '''
    Collect the seeds from all the `*.sd` files in the current directory.

    Returns
    -------
    list
        The seeds, as integers.

    Raises
    ------
    RuntimeError
        If no seeds were found.
    ValueError
        If an entry cannot be converted to an integer.
    '''
    # glob returns the paths in arbitrary order, sort them so that seeds are always processed in the same order
    l_seed_file = sorted(glob.glob('*.sd'))

    l_seed  = []
    for seed_file in l_seed_file:
        l_seed += get_file_seeds(seed_file)

    if not l_seed:
        log.error(f'No seeds found')
        # RuntimeError with an explicit message, a bare raise would only report
        # that there is no active exception
        raise RuntimeError('No seeds found')

    log.debug(f'Using seeds: {l_seed}')

    return [ int(rseed) for rseed in l_seed ]
#--------------------------------
# Run the toy fits only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()

