#!/usr/bin/env python3

import os
import sys
import pickle
import json
import hydra
from typing import Dict, List
from pathlib import Path

import numpy as np

from lm_polygraph.normalizers.minmax import MinMaxNormalizer
from lm_polygraph.normalizers.quantile import QuantileNormalizer
from lm_polygraph.normalizers.binned_pcc import BinnedPCCNormalizer
from lm_polygraph.normalizers.isotonic_pcc import IsotonicPCCNormalizer
from lm_polygraph.utils.normalize import get_mans_ues_metrics, filter_nans


# Path to the Hydra config file, read from the HYDRA_CONFIG environment
# variable at import time; a missing variable raises KeyError here. Its parent
# directory and file name are passed to @hydra.main below.
hydra_config = Path(os.environ["HYDRA_CONFIG"])


def _fit_normalizer(method, gen_metric_values, ue_values, args):
    """Instantiate and fit the normalizer selected by ``method``.

    Args:
        method: One of "min_max", "quantile", "binned_pcc" or "isotonic_pcc".
        gen_metric_values: Generation-metric values; only passed to the
            "binned_pcc" and "isotonic_pcc" normalizers.
        ue_values: Uncertainty-estimate values to fit on.
        args: Config object; ``args.num_bins`` is read only for "binned_pcc".

    Returns:
        The fitted normalizer instance.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "min_max":
        fitted = MinMaxNormalizer()
        fitted.fit(ue_values)
        return fitted

    if method == "quantile":
        fitted = QuantileNormalizer()
        fitted.fit(ue_values)
        return fitted

    if method == "binned_pcc":
        fitted = BinnedPCCNormalizer()
        fitted.fit(gen_metric_values, ue_values, args.num_bins)
        return fitted

    if method == "isotonic_pcc":
        fitted = IsotonicPCCNormalizer()
        fitted.fit(gen_metric_values, ue_values)
        return fitted

    raise ValueError(f"Unknown normalization method: {method}")


@hydra.main(
    version_base=None,
    config_path=str(hydra_config.parent),
    config_name=str(hydra_config.name),
)
def fit(args):
    """Fit normalizers for every metric / UE-method pair and save them.

    For each generation metric and each uncertainty-estimation method returned
    by ``get_mans_ues_metrics``, the paired values are passed through
    ``filter_nans`` and every method in ``args.normalization_methods`` is
    fitted on the result. The output of each normalizer's ``dumps()`` is
    stored under the key ``(metric_name, ue_method_name, normalization_method)``.

    Args:
        args: Config object providing ``man_paths``, ``ue_method_names``,
            ``gen_metric_names``, ``normalization_methods``, ``save_path`` and,
            when "binned_pcc" is requested, ``num_bins``.

    Raises:
        ValueError: If an unknown normalization method is requested.
    """
    ues, gen_metrics = get_mans_ues_metrics(
        args.man_paths,
        args.ue_method_names,
        args.gen_metric_names,
    )

    fitted_normalizers = {}
    for metric_name, metric_values in gen_metrics.items():
        for ue_name, ue_values in ues.items():
            # Filtering depends only on the (metric, UE method) pair, so it is
            # done once and shared by all normalization methods.
            clean_metrics, clean_ues = filter_nans(metric_values, ue_values)

            for method in args.normalization_methods:
                fitted = _fit_normalizer(method, clean_metrics, clean_ues, args)
                key = (metric_name, ue_name, method)
                fitted_normalizers[key] = fitted.dumps()

    Path(args.save_path).mkdir(parents=True, exist_ok=True)
    # NOTE: despite the ".json" extension, the file is a binary pickle of the
    # tuple-keyed dict, so it must be read back with pickle, not json.
    output_file = args.save_path + "/fitted_normalizers.json"
    with open(output_file, "wb") as sink:
        pickle.dump(fitted_normalizers, sink)


# Script entry point. fit() is called without arguments; the @hydra.main
# wrapper is expected to build the config and pass it in as `args`.
if __name__ == "__main__":
    fit()
