#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
Computes the BLEU, ROUGE, METEOR, and CIDEr scores using the COCO metrics scripts.
"""
import pathlib
import argparse
from collections import OrderedDict

# Script taken and adapted from Kelvin Xu's arctic-captions project
# https://github.com/kelvinxu/arctic-captions

from nmtpytorch.cocoeval import Bleu, Meteor, Cider, Rouge
from nmtpytorch.utils.misc import get_meteor_jar


def print_table(results, sort_by='METEOR'):
    """Print a fixed-width table of metric scores, one row per system.

    Args:
        results: Mapping from system name to a mapping of metric name to
            numeric score.
        sort_by: Name of the metric used to order the rows (ascending).

    The system name is printed above its row only when more than one
    system is given. A metric that is absent from a system's scores is
    rendered as ``n/a`` and ranks below every real score when sorting,
    instead of aborting the whole table with a ``KeyError``.
    """
    cols = ['Bleu_1', 'Bleu_2', 'Bleu_3', 'Bleu_4',
            'METEOR', 'CIDEr', 'ROUGE_L']
    print(''.join('|{:^15}|'.format(col) for col in cols))

    # Systems lacking the sort metric are placed first (lowest rank).
    ranked = sorted(results.items(),
                    key=lambda item: item[1].get(sort_by, float('-inf')))
    show_names = len(ranked) > 1

    for sysname, result in ranked:
        if show_names:
            print(sysname)
        cells = []
        for col in cols:
            if col in result:
                cells.append('|{:^15,.3f}|'.format(result[col]))
            else:
                # Keep the column width so the table stays aligned.
                cells.append('|{:^15}|'.format('n/a'))
        print(''.join(cells))

if __name__ == '__main__':
    # Command-line entry point: score one or more hypothesis files against a
    # shared set of reference files and print the metrics as a table.
    parser = argparse.ArgumentParser(prog='coco-metrics')

    parser.add_argument("-w", "--write", action='store_true',
                        help='Create a .score file containing the results.')
    parser.add_argument("-l", "--language", default='en',
                        help='Hypothesis language (default: en)')
    # References are mandatory: without them ``args.refs`` would be None and
    # the script would fail later with an opaque TypeError.
    parser.add_argument("-r", "--refs", type=argparse.FileType('r'),
                        required=True,
                        help="Path to all the reference files", nargs='+')
    parser.add_argument("systems", type=str,
                        help="Per-system hypothesis file(s)", nargs='+')

    args = parser.parse_args()

    # Check for METEOR
    get_meteor_jar()

    # Each scorer is paired with the metric name(s) its score maps onto.
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(args.language), ["METEOR"]),
        (Cider(), ["CIDEr"]),
        (Rouge(), ["ROUGE_L"]),
    ]

    results = OrderedDict()

    # Read the reference files in lockstep: entry i holds the i-th line of
    # every reference file. Close the handles argparse opened once done.
    try:
        raw_refs = [list(map(str.strip, r)) for r in zip(*args.refs)]
    finally:
        for ref_file in args.refs:
            ref_file.close()
    refs = dict(enumerate(raw_refs))

    # Ranking of multiple systems is possible
    for hypfile in args.systems:
        # One single-element list of hypothesis text per sentence index.
        with open(hypfile) as hyp_fh:
            hypo = {idx: [line.strip()] for idx, line in enumerate(hyp_fh)}

        result = OrderedDict()

        for scorer, method in scorers:
            score, _ = scorer.compute_score(refs, hypo)
            # Only skip a scorer that produced nothing; a score of 0 is a
            # valid result and must still appear in the table.
            if score is None:
                continue
            if not isinstance(score, list):
                score = [score]
            for m, s in zip(method, score):
                result[m] = float('%.3f' % s)

        if args.write:
            with open("%s.score" % hypfile, 'w') as score_fh:
                score_fh.write("%s\n" % result)
        results[str(pathlib.Path(hypfile))] = result

    print_table(results)
