#!/usr/bin/env python
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import sys
import os
import momma_dragonn
from collections import OrderedDict
import avutils.file_processing as fp

def momma_dragonn_train(options):
    """Train one model per hyperparameter setting, then run the
    end-of-training callbacks on each result.

    The validation data loader, the model evaluator and all callbacks are
    loaded once from ``options`` and reused for every hyperparameter
    setting found in ``options.hyperparameter_configs_list``.

    Args:
        options: object (e.g. an ``argparse.Namespace``) exposing the
            attributes ``valid_data_loader_config``, ``evaluator_config``,
            ``end_of_epoch_callbacks_config``,
            ``start_of_epoch_callbacks_config`` (``None`` means "no
            start-of-epoch callbacks"), ``end_of_training_callbacks_config``
            and ``hyperparameter_configs_list``.

    Returns:
        None. Results are only communicated through the callbacks.
    """
    valid_data_loader = momma_dragonn.loaders.load_data_loader(
        config=options.valid_data_loader_config)
    model_evaluator = momma_dragonn.loaders.load_model_evaluator(
        options.evaluator_config)
    larger_is_better = model_evaluator.is_larger_better_for_key_metric()

    end_of_epoch_callbacks = momma_dragonn.loaders.load_epoch_callbacks(
        config=options.end_of_epoch_callbacks_config)
    # Start-of-epoch callbacks are optional; fall back to an empty list.
    if options.start_of_epoch_callbacks_config is None:
        start_of_epoch_callbacks = []
    else:
        start_of_epoch_callbacks = momma_dragonn.loaders.load_epoch_callbacks(
            config=options.start_of_epoch_callbacks_config)

    end_of_training_callbacks = (
        momma_dragonn.loaders.load_end_of_training_callbacks(
            config=options.end_of_training_callbacks_config,
            key_metric_name=model_evaluator.get_key_metric_name(),
            larger_is_better=larger_is_better))

    hyperparameter_settings = fp.load_yaml_if_string(
        options.hyperparameter_configs_list)
    for hyperparameter_setting_yaml in hyperparameter_settings:
        hyperparameter_setting = (
            momma_dragonn.loaders.load_hyperparameter_block(
                hyperparameter_setting_yaml))
        other_data_loaders = hyperparameter_setting["other_data_loaders"]
        model_creator = hyperparameter_setting["model_creator"]
        model_trainer = hyperparameter_setting["model_trainer"]
        message = hyperparameter_setting["message"]

        model_wrapper, performance_history, training_metadata = (
            model_trainer.train(
                model_creator=model_creator,
                model_evaluator=model_evaluator,
                valid_data_loader=valid_data_loader,
                other_data_loaders=other_data_loaders,
                end_of_epoch_callbacks=end_of_epoch_callbacks,
                start_of_epoch_callbacks=start_of_epoch_callbacks,
                error_callbacks=[]))

        # Only report runs that trained for at least one epoch.
        # The accessor has to be *called*: comparing the bound method
        # itself against None is always true, which made this guard a no-op.
        # NOTE(review): assumes get_best_valid_epoch_perf_info is a regular
        # method (like the other get_* accessors used here) -- confirm.
        if performance_history.get_best_valid_epoch_perf_info() is None:
            continue

        for end_of_training_callback in end_of_training_callbacks:
            # The jsonable descriptions are rebuilt for every callback so
            # that no callback sees objects touched by a previous one.
            end_of_training_callback(  # e.g. handles writing to db
                performance_history=performance_history,
                model_wrapper=model_wrapper,
                training_metadata=training_metadata,
                message=message,
                model_creator_info=model_creator.get_jsonable_object(),
                model_trainer_info=model_trainer.get_jsonable_object(),
                other_data_loaders_info=OrderedDict(
                    (split_name, data_loader.get_jsonable_object())
                    for (split_name, data_loader)
                    in other_data_loaders.items()))

if __name__ == "__main__":
    # Script entry point: collect the config locations from the command
    # line and hand them to momma_dragonn_train. Flags and defaults are
    # unchanged; only the --help output is richer.
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for every hyperparameter setting and"
                    " run the end-of-training callbacks on each result.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        "--valid_data_loader_config",
        default="config/valid_data_loader_config.yaml",
        help="config used to load the validation data loader")
    parser.add_argument(
        "--evaluator_config",
        default="config/evaluator_config.yaml",
        help="config used to load the model evaluator")
    parser.add_argument(
        "--end_of_epoch_callbacks_config",
        default="config/end_of_epoch_callbacks_config.yaml",
        help="config used to load the end-of-epoch callbacks")
    parser.add_argument(
        "--start_of_epoch_callbacks_config",
        default=None,
        help="config used to load the start-of-epoch callbacks;"
             " when omitted no start-of-epoch callbacks are used")
    parser.add_argument(
        "--end_of_training_callbacks_config",
        default="config/end_of_training_callbacks_config.yaml",
        help="config used to load the end-of-training callbacks")
    parser.add_argument(
        "--hyperparameter_configs_list",
        default="config/hyperparameter_configs_list.yaml",
        help="list of hyperparameter settings; one training run is"
             " performed per entry")
    options = parser.parse_args()
    momma_dragonn_train(options)
