#!/usr/bin/env python

import sys
import torch

from nmtpytorch.config import Options
from nmtpytorch.metrics import Evaluator
from nmtpytorch.utils.misc import load_pt_file

if __name__ == '__main__':
    # Command-line entry point: print a short summary of the training
    # history stored inside a checkpoint file.
    #
    # Usage: <script> <.ckpt file>
    # Exits with status 1 when no file is given or when the file carries
    # no history information.
    if len(sys.argv) < 2:
        print('Usage: {} <.ckpt file>'.format(sys.argv[0]))
        sys.exit(1)
    pt_file = sys.argv[1]

    # Only the history and the options are reported; weights are unused.
    _weights, history, opts = load_pt_file(pt_file)

    if not history:
        print('This is not a .ckpt file with history information.')
        sys.exit(1)

    opts = Options.from_dict(opts)

    # The first entry of the comma-separated metric list is reported as the
    # early-stop metric.
    early_metric = opts.train['eval_metrics'].split(',')[0]

    print('Checkpoint saved at epoch: {} update: {}'.format(history['ectr'],
                                                            history['uctr']))
    for epoch, loss in enumerate(history['epoch_losses'], start=1):
        print('- Epoch {:<3} loss: {:.3f}'.format(epoch, loss))

    # Report the validation counter here, not the epoch counter ('ectr')
    # that was already printed above as the epoch number.
    # NOTE(review): assumes the history stores the number of validations
    # under 'vctr' -- confirm against the code that writes the checkpoint.
    print('- Did {} validations with early-stop metric "{}"'.format(
        history['vctr'],
        early_metric))

    # For every tracked metric, show its best value and the validation
    # index at which Evaluator.find_best located it.
    for metric, hist in history['evals'].items():
        best_vctr, best_val = Evaluator.find_best(metric, hist)
        print('- Best {:<10} so far: {:.2f} (Validation {})'.format(metric,
                                                                    best_val,
                                                                    best_vctr))
