#!/usr/bin/env python3

import argparse
import os

from logzero import logger as log

from rk.pr_shapes import pr_maker as prm

#------------------------------
class data:
    '''
    Module-wide configuration namespace (used as a class, never instantiated).

    The `l_*` lists are the defaults handed to argparse in get_args(); the
    remaining attributes start as None and are filled by get_args() from
    the command line before main() reads them.
    '''
    # Defaults offered on the command line
    l_q2bin = ['jpsi', 'psi2', 'high']
    l_trig  = ['ETOS']
    l_sample= ['bdXcHs', 'bpXcHs']

    # Run settings, populated by get_args()
    year = q2bin = trig = sample = addvar = ivers = overs = None
#------------------------------
def get_args():
    '''
    Parse the command line and store the results on the `data` namespace.

    Years, input version and output version are mandatory; q2 bins,
    triggers and samples fall back to the `data.l_*` defaults, and the
    extra-variable list defaults to empty.
    '''
    parser = argparse.ArgumentParser(description='Used to produce JSON files with masses needed for making KDE PRec PDFs')
    parser.add_argument('-y', '--year'    , nargs='+', help='Samples year', required=True)
    parser.add_argument('-q', '--q2bin'   , nargs='+', help='q2 bin'      , default=data.l_q2bin)
    parser.add_argument('-t', '--trig'    , nargs='+', help='Trigger'     , default=data.l_trig)
    parser.add_argument('-s', '--sample'  , nargs='+', help='Sample'      , default=data.l_sample)
    parser.add_argument('-a', '--addvar'  , nargs='+', help='Variables to add', default=[])
    parser.add_argument('-i', '--iversion', type=str , help='Input version' , required=True)
    parser.add_argument('-o', '--oversion', type=str , help='Output version', required=True)

    parsed = vars(parser.parse_args())

    # attribute on `data`  <-  argparse destination name
    destination = (
        ('year'  , 'year'),
        ('q2bin' , 'q2bin'),
        ('trig'  , 'trig'),
        ('sample', 'sample'),
        ('addvar', 'addvar'),
        ('ivers' , 'iversion'),
        ('overs' , 'oversion'),
    )
    for attr, key in destination:
        setattr(data, attr, parsed[key])
#------------------------------
def check_version():
    '''
    Refuse to overwrite an existing output version.

    The output directory is `$PRCDIR/<data.overs>`.

    Raises
    ------
    KeyError
        If the PRCDIR environment variable is not set.
    RuntimeError
        If the output version directory already exists.
    '''
    prc_dir = os.environ.get('PRCDIR')
    if prc_dir is None:
        log.error('PRCDIR environment variable is not set')
        raise KeyError('PRCDIR')

    ver_dir = f'{prc_dir}/{data.overs}'

    if os.path.isdir(ver_dir):
        log.error(f'Version already exists: {ver_dir}')
        # A bare `raise` here (no active exception) only produced
        # "No active exception to reraise"; raise a real error with context.
        raise RuntimeError(f'Version already exists: {ver_dir}')
#------------------------------
def get_channel(trig):
    '''
    Map a trigger name to its dilepton channel tag.

    Parameters
    ----------
    trig : str
        Trigger name: 'ETOS' or 'GTIS' (electron) or 'MTOS' (muon).

    Returns
    -------
    str
        'ee' for ETOS/GTIS, 'mm' for MTOS.

    Raises
    ------
    RuntimeError
        If the trigger is not one of the known names.
    '''
    if   trig in ['ETOS', 'GTIS']:
        chan = 'ee'
    elif trig in ['MTOS']:
        chan = 'mm'
    else:
        log.error(f'Invalid trigger: {trig}')
        # A bare `raise` outside an except block does not carry this message;
        # raise an explicit error instead (same exception type as before).
        raise RuntimeError(f'Invalid trigger: {trig}')

    return chan
#------------------------------
def _save_trigger_samples(year, q2bin, trig):
    '''
    Build and save the pr_maker output of every configured sample for one
    (year, q2 bin, trigger) combination.
    '''
    channel = get_channel(trig)
    for smp in data.sample:
        # what pr_maker/save_data write is defined in rk.pr_shapes; not visible here
        maker = prm(f'{smp}_{channel}', year, trig, data.ivers, q2bin)
        maker.save_data(version=data.overs, add_vars=data.addvar)
#------------------------------
def main():
    '''
    Run the production over every year x q2 bin x trigger x sample in `data`.

    The output version is checked before any work is done, so an existing
    version aborts the run up front.
    '''
    check_version()

    for yr in data.year:
        for qbin in data.q2bin:
            for trigger in data.trig:
                _save_trigger_samples(yr, qbin, trigger)
#------------------------------
if __name__ == '__main__':
    # Fill the shared `data` namespace from the command line first;
    # main() takes all of its settings from there.
    get_args()
    main()
