#!/usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import json
import pathlib
import argparse
from collections import OrderedDict

import numpy as np

from nmtpytorch.vocabulary import Vocabulary
from nmtpytorch.utils.misc import pbar


def freqs_to_dict(token_freqs, min_freq=0, max_items=0, exclude_symbols=False):
    """Turn a token -> count mapping into an ordered vocabulary mapping.

    The result maps every kept token to the string ``"<index> <count>"``.
    Unless ``exclude_symbols`` is set, the entries of ``Vocabulary.TOKENS``
    are inserted first, each with its own value as index and a count of 0.
    Corpus tokens follow in descending order of frequency; tokens with equal
    frequency keep the order in which they appear in ``token_freqs`` so that
    the output is reproducible.

    Arguments:
        token_freqs: Mapping from token to its occurrence count.
        min_freq: If > 0, tokens occurring fewer than ``min_freq`` times
            are dropped.
        max_items: If > 0, at most this many corpus tokens are kept (the
            special symbols are not counted against this limit).
        exclude_symbols: If True, the ``Vocabulary.TOKENS`` entries are
            not inserted.

    Returns:
        An ``OrderedDict`` of ``token -> "index count"`` strings.
    """
    tokendict = OrderedDict()
    if not exclude_symbols:
        for key, value in Vocabulary.TOKENS.items():
            # Second value is the count information
            tokendict[key] = "{} 0".format(value)

    # A corpus token named like a special symbol would overwrite that
    # symbol's entry below, so such tokens are left out here.
    candidates = [(token, freq) for token, freq in token_freqs.items()
                  if token not in tokendict]

    if min_freq > 0:
        candidates = [(token, freq) for token, freq in candidates
                      if freq >= min_freq]

    # list.sort() is stable even with reverse=True, hence ties stay in
    # first-seen order instead of depending on the sorting backend.
    candidates.sort(key=lambda item: item[1], reverse=True)

    if max_items > 0:
        candidates = candidates[:max_items]

    # Start inserting from index offset
    offset = len(tokendict)
    for iidx, (token, freq) in enumerate(candidates):
        tokendict[token] = '{} {}'.format(iidx + offset, int(freq))

    return tokendict


def get_freqs(filename, cumul_dict=None):
    """Count space-separated tokens of a text file.

    Arguments:
        filename: Path of the UTF-8 encoded sentence file to read.
        cumul_dict: Optional mapping of already collected counts. When
            given, it is updated in place and returned, so that counts
            accumulate over several files.

    Returns:
        The token -> count mapping (``cumul_dict`` itself if it was
        given, a new ``OrderedDict`` otherwise), without any key that
        is present in ``Vocabulary.TOKENS``.
    """
    # Accumulate into the caller's dictionary if one was provided
    token_freqs = cumul_dict if cumul_dict is not None else OrderedDict()

    print("Reading file %s" % filename)
    # Explicit encoding: do not depend on the locale of the machine
    with open(filename, encoding='utf-8') as fhandle:
        for line in pbar(fhandle, unit='lines'):
            for word in line.strip().split(' '):
                # Blank lines and runs of consecutive spaces produce
                # empty strings, which are not tokens.
                if word:
                    token_freqs[word] = token_freqs.get(word, 0) + 1

    # Remove already available special tokens
    for key in Vocabulary.TOKENS:
        if key in token_freqs:
            print('Removing ', key)
            del token_freqs[key]

    return token_freqs


def write_dict(fname, vocab):
    """Dump the vocabulary mapping ``vocab`` to ``fname`` as indented JSON.

    Non-ASCII tokens are written verbatim (``ensure_ascii=False``), so the
    file is always opened as UTF-8 instead of relying on the locale's
    default encoding.
    """
    print("Dumping vocabulary (%d tokens) to %s..." % (len(vocab), fname))
    with open(fname, 'w', encoding='utf-8') as fhandle:
        json.dump(vocab, fhandle, ensure_ascii=False, indent=2)


def main():
    """Command-line entry point: build JSON vocabularies from sentence files.

    Without ``-s/--single`` one vocabulary is written per input file into
    the output directory, named ``<stem>[-min<m>][-max<M>tokens].vocab<ext>``
    where ``<stem>`` and ``<ext>`` are the stem and suffix of the input
    file name. With ``-s/--single`` the counts of all input files are
    accumulated and one vocabulary is written to the path given to ``-s``,
    used as is (the output directory is not applied to it).
    """
    parser = argparse.ArgumentParser(prog='build-vocab')
    parser.add_argument('-o', '--output-dir', type=str, default='.',
                        help='Output directory')
    parser.add_argument('-s', '--single', type=str, default=None,
                        help='Name of the combined vocabulary file')
    parser.add_argument('-m', '--min-freq', type=int, default=0,
                        help='Filter out tokens occuring < m times')
    parser.add_argument('-M', '--max-items', type=int, default=0,
                        help='Keep the final vocabulary size less than this')
    parser.add_argument('-x', '--exclude-symbols', action='store_true',
                        help='Do not add special <eos>, <bos>, <pad>, <unk>')
    parser.add_argument('files', type=str, nargs='+',
                        help='Sentence files')
    args = parser.parse_args()

    if args.exclude_symbols:
        print('Warning: -x does not create vocabularies compatible '
              'with many nmtpytorch\'s models.')

    output_dir = pathlib.Path(args.output_dir).expanduser()

    if not args.single:
        # Create the target directory before any counting starts, so that a
        # missing directory does not abort the run only at write time.
        output_dir.mkdir(parents=True, exist_ok=True)

    # Cumulative counts, only filled in --single mode
    all_freqs = OrderedDict()

    for filename in args.files:
        filename = pathlib.Path(filename).expanduser()

        if args.single:
            # Get cumulative frequencies
            all_freqs = get_freqs(filename, all_freqs)
            continue

        # Get frequencies
        freqs = get_freqs(filename)
        # Build dictionary from frequencies
        tokendict = freqs_to_dict(
            freqs, args.min_freq, args.max_items, args.exclude_symbols)

        # The output name encodes the filtering options that were used
        suffix = ".vocab{}".format(filename.suffix)
        vocab_fname = filename.stem
        if args.min_freq > 0:
            vocab_fname += "-min%d" % args.min_freq
        if args.max_items > 0:
            vocab_fname += "-max%dtokens" % args.max_items
        vocab_fname = str((output_dir / vocab_fname)) + suffix
        write_dict(vocab_fname, tokendict)

    if args.single:
        vocab_fname = pathlib.Path(args.single)
        tokendict = freqs_to_dict(
            all_freqs, args.min_freq, args.max_items, args.exclude_symbols)
        write_dict(vocab_fname, tokendict)

if __name__ == '__main__':
    # Run the command-line tool and hand its return value to the
    # interpreter as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
