#!/usr/bin/env python3

import hydra
import importlib
import os
import torch
import transformers
import argparse
from pathlib import Path
import json

import logging

# Module-level logger registered under the 'lm_polygraph' name; every
# function in this script reports progress through it.
log = logging.getLogger('lm_polygraph')

from lm_polygraph.utils.manager import UEManager
from lm_polygraph.utils.dataset import Dataset
from lm_polygraph.utils.model import WhiteboxModel, create_ensemble
from lm_polygraph.utils.processor import Logger
from lm_polygraph.generation_metrics import *
from lm_polygraph.estimators import *
from lm_polygraph.utils.openai_chat import OpenAIChat
from lm_polygraph.estimators.ensemble_token_measures import all_token_estimators
from lm_polygraph.estimators.ensemble_sequence_measures import all_ep_estimators, all_pe_estimators
from lm_polygraph.estimators.ensemble_token_measures import *
from lm_polygraph.ue_metrics import *

# Location of the Hydra config file, read from the HYDRA_CONFIG environment
# variable at import time (a missing variable raises KeyError here). Its parent
# directory and file name are passed to the @hydra.main decorator of main().
hydra_config = Path(os.environ["HYDRA_CONFIG"])

def _load_task_dataset(args, dataset_name, split, cache_kwargs, **extra_kwargs):
    """Call ``Dataset.load`` with the options shared by eval and train data.

    Args:
        args: Run configuration. ``text_column``, ``batch_size`` and
            ``load_from_disk`` are required; ``label_column``, ``prompt``,
            ``description``, ``mmlu_max_subject_size``, ``n_shot`` and
            ``few_shot_split`` are optional and fall back to the defaults
            spelled out below.
        dataset_name: First positional argument forwarded to ``Dataset.load``.
        split: Value forwarded as the ``split`` keyword.
        cache_kwargs: Extra keyword arguments (possibly empty) forwarded as-is.
        **extra_kwargs: Additional keywords for ``Dataset.load`` (e.g. ``size``).

    Returns:
        Whatever ``Dataset.load`` returns.
    """
    return Dataset.load(
        dataset_name,
        args.text_column,
        getattr(args, "label_column", None),
        batch_size=args.batch_size,
        prompt=getattr(args, "prompt", ""),
        description=getattr(args, "description", ""),
        mmlu_max_subject_size=getattr(args, "mmlu_max_subject_size", 100),
        n_shot=getattr(args, "n_shot", 5),
        few_shot_split=getattr(args, "few_shot_split", "train"),
        split=split,
        load_from_disk=args.load_from_disk,
        **extra_kwargs,
        **cache_kwargs,
    )


def _load_train_datasets(args, dataset, seed, cache_kwargs):
    """Load the train and background-train datasets and subsample them if requested.

    The train data comes from one of three sources, checked in this order:
    a dedicated ``args.train_dataset`` that differs from ``args.dataset``;
    a split of the already loaded eval ``dataset`` when ``args.train_test_split``
    is truthy; otherwise the ``args.train_split`` split of ``args.dataset``.

    Args:
        args: Run configuration.
        dataset: The already loaded evaluation dataset.
        seed: Seed used for the train/test split and for subsampling.
        cache_kwargs: Extra keyword arguments forwarded to ``Dataset.load``.

    Returns:
        Tuple ``(train_dataset, background_train_dataset)``.
    """
    if (args.train_dataset is not None) and (args.train_dataset != args.dataset):
        train_dataset = _load_task_dataset(
            args, args.train_dataset, args.train_split, cache_kwargs, size=10_000
        )
    elif args.train_test_split:
        # Only the train part is used here; the test part of the returned
        # tuple is discarded, exactly as before.
        X_train, _, y_train, _ = dataset.train_test_split(
            test_size=args.test_split_size, seed=seed, split=args.eval_split
        )
        train_dataset = Dataset(x=X_train, y=y_train, batch_size=args.batch_size)
    else:
        train_dataset = _load_task_dataset(
            args, args.dataset, args.train_split, cache_kwargs, size=10_000
        )

    background_train_dataset = Dataset.load(
        args.background_train_dataset,
        args.background_train_dataset_text_column,
        args.background_train_dataset_label_column,
        batch_size=args.batch_size,
        data_files=args.background_train_dataset_data_files,
        split="train",
        size=100_000,
        load_from_disk=args.background_load_from_disk,
        **cache_kwargs
    )

    # A value of -1 means "keep the whole dataset".
    if args.subsample_train_dataset != -1:
        train_dataset.subsample(args.subsample_train_dataset, seed=seed)
    if args.subsample_background_train_dataset != -1:
        background_train_dataset.subsample(
            args.subsample_background_train_dataset, seed=seed
        )

    return train_dataset, background_train_dataset


@hydra.main(
    version_base=None,
    config_path=str(hydra_config.parent),
    config_name=str(hydra_config.name),
)
def main(args):
    """Run the uncertainty-estimation benchmark once per configured seed.

    For every seed the model (and, if ``args.model.ensemble`` is set, an
    MC ensemble) is loaded, the evaluation data and estimators are prepared,
    training data is loaded when at least one density-based estimator is not
    fitted yet, and a ``UEManager`` is executed and saved to
    ``<save_path>/ue_manager_seed<seed>``.

    Args:
        args: Configuration object supplied by the ``hydra.main`` decorator.
    """
    # The working directory at entry is the default output location; afterwards
    # switch back to the original working directory reported by Hydra.
    save_path = os.getcwd()
    log.info(f"Main directory: {save_path}")
    os.chdir(hydra.utils.get_original_cwd())

    save_path = args.save_path if "save_path" in args else save_path

    if args.seed is None or len(args.seed) == 0:
        args.seed = [1]

    model_kwargs = get_model_kwargs(args)

    # cache_path is only forwarded to the loaders when HF_DATASETS_OFFLINE=1.
    cache_kwargs = {}
    if os.environ.get('HF_DATASETS_OFFLINE', '').strip() == '1':
        cache_kwargs = {'cache_dir': args.cache_path}

    for seed in args.seed:
        log.info("=" * 100)
        log.info(f"SEED: {seed}")

        log.info(f"Loading model {args.model.path}...")
        transformers.set_seed(seed)

        model = WhiteboxModel.from_pretrained(
            args.model.path,
            generation_params=getattr(args, "generation_params", {}),
            **cache_kwargs,
            **model_kwargs
        )

        if args.model.ensemble:
            # Only MC-ensembles for now
            log.info("Creating ensemble...")
            # NOTE(review): the ensemble always receives the first configured
            # seed rather than the current loop seed - confirm this is intended.
            ensemble_model = create_ensemble(
                model_paths=[args.model.path],
                mc=True,
                seed=args.seed[0],
                ensembling_mode=args.model.ensembling_mode,
                mc_seeds=args.model.mc_seeds,
                dropout_rate=float(args.model.dropout_rate),
                **cache_kwargs,
                **model_kwargs
            )
        else:
            ensemble_model = None

        log.info("Done with loading model.")

        log.info(f"Loading dataset {args.dataset}...")
        dataset = _load_task_dataset(args, args.dataset, args.eval_split, cache_kwargs)
        log.info("Done with loading eval data.")

        log.info("=" * 100)
        log.info("Initializing UE estimators...")
        estimators = []
        estimators += get_ue_methods(args, model)
        density_based_ue_methods = get_density_based_ue_methods(args, model.model_type)
        estimators += density_based_ue_methods
        log.info("Done loading UE estimators")

        # Training data is only needed when some density-based estimator does
        # not report itself as fitted.
        needs_training_data = any(
            not getattr(method, "is_fitted", False)
            for method in density_based_ue_methods
        )
        if needs_training_data:
            log.info("=" * 100)
            log.info("Loading train dataset...")
            train_dataset, background_train_dataset = _load_train_datasets(
                args, dataset, seed, cache_kwargs
            )
            log.info("Done loading train data.")
        else:
            train_dataset = None
            background_train_dataset = None

        # A value of -1 means "keep the whole dataset".
        if args.subsample_eval_dataset != -1:
            dataset.subsample(args.subsample_eval_dataset, seed=seed)

        generation_metrics = get_generation_metrics(args)

        ue_metrics = get_ue_metrics(args)

        man = UEManager(
            dataset,
            model,
            estimators,
            generation_metrics,
            ue_metrics,
            [
                Logger(),
            ],
            deberta_batch_size=getattr(args, 'deberta_batch_size', 10),
            train_data=train_dataset,
            ignore_exceptions=args.ignore_exceptions,
            background_train_data=background_train_dataset,
            max_new_tokens=args.max_new_tokens,
            ensemble_model=ensemble_model,
            cache_path=args.cache_path,
        )

        man()

        man.save(f"{save_path}/ue_manager_seed{seed}")


def get_ue_metrics(args):
    """Return the metrics used to score the uncertainty estimates.

    Three metrics are always included; ``ROCAUC`` and ``PRAUC`` are appended
    when the optional ``use_claim_ue`` flag on ``args`` is truthy.

    Args:
        args: Run configuration; only the optional ``use_claim_ue`` is read.

    Returns:
        A new list of metric instances.
    """
    selected = [
        ReversedPairsProportion(),
        PredictionRejectionArea(),
        RiskCoverageCurveAUC(),
    ]
    if not getattr(args, "use_claim_ue", False):
        return selected

    selected.extend([ROCAUC(), PRAUC()])
    return selected


def get_density_based_ue_methods(args, model_type):
    """Build the density-based estimators, or an empty list when disabled.

    Args:
        args: Run configuration. ``use_density_based_ue`` switches the feature
            on. When the optional ``parameters_path`` is missing or falsy, a
            path is derived from ``cache_path``, ``dataset`` and ``model.path``.
        model_type: ``"Seq2SeqLM"`` yields estimators for both ``"encoder"``
            and ``"decoder"``; any other value yields ``"decoder"`` ones only.

    Returns:
        List of estimators: Mahalanobis, relative Mahalanobis and RDE
        (each per embedding type), followed by the PPLMD variants
        (``"MD"`` then ``"RMD"`` for each embedding type).
    """
    if not args.use_density_based_ue:
        return []

    parameters_path = getattr(args, 'parameters_path', False)
    if not parameters_path:
        raw_dataset = args.dataset
        if not isinstance(raw_dataset, str):
            # Several dataset name parts are glued together with underscores.
            raw_dataset = '_'.join(raw_dataset)
        # Keep what follows the last "/" and precedes the first "." after it.
        dataset_name = raw_dataset.rpartition("/")[2].partition(".")[0]
        model_name = args.model.path.rpartition("/")[2]
        parameters_path = f"{args.cache_path}/density_stats/{dataset_name}/{model_name}"

    if model_type == "Seq2SeqLM":
        embedding_types = ("encoder", "decoder")
    else:
        embedding_types = ("decoder",)

    estimators = []
    for estimator_cls in (MahalanobisDistanceSeq, RelativeMahalanobisDistanceSeq, RDESeq):
        for embedding_type in embedding_types:
            estimators.append(
                estimator_cls(embedding_type, parameters_path=parameters_path)
            )
    for embedding_type in embedding_types:
        for md_type in ("MD", "RMD"):
            estimators.append(
                PPLMDSeq(embedding_type, md_type=md_type, parameters_path=parameters_path)
            )
    return estimators


def get_ue_methods(args, model):
    """Build the list of uncertainty estimators selected by the configuration.

    Args:
        args: Run configuration. Reads the flags ``use_seq_ue``, ``use_ens_ue``
            and ``use_tok_ue``, the optional flag ``use_claim_ue``,
            ``model.ensembling_mode`` (only with ``use_ens_ue``),
            ``cache_path`` (only with ``use_tok_ue``) and the optional mappings
            ``additional_estimators`` (module name -> list of class names) and
            ``additional_estimators_kwargs`` (class name -> constructor kwargs).
        model: Loaded model; ``model.model_type`` is read with ``use_ens_ue``
            and ``model.model_path`` with ``use_tok_ue``.

    Returns:
        List of estimator instances in the order: sequence-level, ensemble,
        token-level, claim-level, additional.

    Raises:
        NotImplementedError: ``use_ens_ue`` is set for a non-``Seq2SeqLM`` model.
        ValueError: ``args.model.ensembling_mode`` is neither ``"pe"`` nor ``"ep"``.
        TypeError: an additional estimator has no entry in
            ``additional_estimators_kwargs``.
    """
    estimators = []
    if args.use_seq_ue:
        estimators += [
            MaximumSequenceProbability(),
            Perplexity(),
            MeanTokenEntropy(),
            MeanPointwiseMutualInformation(),
            MeanConditionalPointwiseMutualInformation(),
            ClaimConditionedProbability(),
            PTrue(),
            PTrueSampling(),
            MonteCarloSequenceEntropy(),
            MonteCarloNormalizedSequenceEntropy(),
            LexicalSimilarity(metric="rouge1"),
            LexicalSimilarity(metric="rouge2"),
            LexicalSimilarity(metric="rougeL"),
            LexicalSimilarity(metric="BLEU"),
            NumSemSets(),
            EigValLaplacian(similarity_score="NLI_score", affinity="entail"),
            EigValLaplacian(similarity_score="NLI_score", affinity="contra"),
            EigValLaplacian(similarity_score="Jaccard_score"),
            DegMat(similarity_score="NLI_score", affinity="entail"),
            DegMat(similarity_score="NLI_score", affinity="contra"),
            DegMat(similarity_score="Jaccard_score"),
            Eccentricity(similarity_score="NLI_score", affinity="entail"),
            Eccentricity(similarity_score="NLI_score", affinity="contra"),
            Eccentricity(similarity_score="Jaccard_score"),
            SemanticEntropy(),
            SAR(),
            TokenSAR(),
            SentenceSAR(),
            RenyiNeg(),
            FisherRao(),
        ]

    if args.use_ens_ue:
        if model.model_type != "Seq2SeqLM":
            raise NotImplementedError('Only Encoder-Decoder models can be ensembled at this time')

        # Validate the mode before building anything, and report the value that
        # was actually checked (the message used to read a non-existent
        # ``args.ens_type`` and failed before the ValueError could be raised).
        ensembling_mode = args.model.ensembling_mode
        if ensembling_mode == 'pe':
            build_sequence_measures = all_pe_estimators
        elif ensembling_mode == 'ep':
            build_sequence_measures = all_ep_estimators
        else:
            raise ValueError(
                f'Ensemble type should be one of: "pe", "ep", but is {ensembling_mode} instead'
            )
        estimators += all_token_estimators() + build_sequence_measures()

    if args.use_tok_ue:
        estimators += [
            MaximumTokenProbability(),
            TokenEntropy(),
            PointwiseMutualInformation(),
            ConditionalPointwiseMutualInformation(),
            SemanticEntropyToken(model.model_path, args.cache_path),
        ]

    if getattr(args, "use_claim_ue", False):
        estimators += [
            MaximumClaimProbability(),
            PerplexityClaim(),
            MaxTokenEntropyClaim(),
            PointwiseMutualInformationClaim(),
            PTrueClaim(),
            ClaimConditionedProbabilityClaim(nli_context="no_context"),
            ClaimConditionedProbabilityClaim(nli_context="fact_pref"),
        ]

    # Treat a missing or None-valued mapping as "nothing configured".
    additional_estimators = getattr(args, "additional_estimators", None) or {}
    additional_estimators_kwargs = getattr(args, "additional_estimators_kwargs", None) or {}

    for module_name, estimator_classes in additional_estimators.items():
        module = importlib.import_module(module_name)
        for estimator_class in estimator_classes:
            try:
                estimator_kwargs = additional_estimators_kwargs[estimator_class]
            except KeyError as err:
                # The message previously referenced an undefined name, which
                # turned this error into a NameError.
                raise TypeError(f'Arguments for {estimator_class} were not passed') from err

            estimators.append(getattr(module, estimator_class)(**estimator_kwargs))

    return estimators


def get_generation_metrics(args):
    """Instantiate the generation-quality metrics for the run.

    When ``args.generation_metrics`` is present and non-empty, each entry is
    read as a mapping with a ``"name"`` (looked up in this module's globals)
    and optional positional ``"args"``. Otherwise a default set is built:
    ``Comet`` is added when ``args.task == "nmt"``; without ``multiref`` a
    ``BartScoreSeqMetric`` is appended, and with ``multiref`` every metric is
    wrapped in ``AggregatedMetric`` instead.

    Args:
        args: Run configuration.

    Returns:
        List of metric instances.

    Raises:
        ValueError: ``BartScoreSeqMetric`` is explicitly configured together
            with ``multiref``.
    """
    log.info("="*100)
    log.info("Initializing generation metrics...")

    multiref = getattr(args, "multiref", False)
    configured = getattr(args, "generation_metrics", None)

    if configured:
        result = []
        for spec in configured:
            name = spec["name"]
            if multiref and name == "BartScoreSeqMetric":
                raise ValueError("BartScoreSeqMetric does not support multiref")
            factory = globals()[name]
            result.append(factory(*spec.get("args", [])))
    else:
        result = [RougeMetric(variant) for variant in ("rouge1", "rouge2", "rougeL")]
        result.append(BertScoreMetric('rh'))
        result.append(SbertMetric())
        result.append(
            AccuracyMetric(
                target_ignore_regex=getattr(args, "target_ignore_regex", None),
                output_ignore_regex=getattr(args, "output_ignore_regex", None),
                normalize=getattr(args, "normalize", False),
            )
        )
        result.append(AlignScore())
        result.append(OpenAIFactCheck(cache_path=args.cache_path))

        if args.task == "nmt":
            source_ignore_regex = getattr(args, "source_ignore_regex", None)
            result.append(Comet(source_ignore_regex=source_ignore_regex))

        if multiref:
            # Wrap each metric in AggregatedMetric
            result = [AggregatedMetric(base_metric=base) for base in result]
        else:
            # Currently, BartScoreSeqMetric does not support multiref
            result.append(BartScoreSeqMetric('rh'))

    log.info("Done with initializing generation metrics.")

    return result


def get_model_kwargs(args):
    """Collect optional keyword arguments for model loading.

    Args:
        args: Run configuration; only ``args.model.device_map`` is consulted.

    Returns:
        ``{'device_map': ...}`` when ``args.model.device_map`` exists and is
        truthy, otherwise an empty dict.
    """
    device_map = getattr(args.model, 'device_map', None)
    if not device_map:
        return {}
    return {'device_map': device_map}


# Run the Hydra-decorated entry point only when this file is executed as a
# script; main() is called without arguments because the decorator supplies
# the configuration.
if __name__ == "__main__":
    main()
