#!/usr/bin/env python

# argument parsing
import argparse

def get_argparser():
    """Build the command-line parser for the prediction console.

    The parser accepts one positional argument, ``model_filepath``, and two
    options: ``--wv`` (string, default ``''``) and ``--topn`` (int,
    default ``10``).

    :return: the configured parser
    :rtype: argparse.ArgumentParser
    """
    parser = argparse.ArgumentParser(
        description='Perform prediction on short text with a given trained model.'
    )

    # Each entry is (positional name/flags, keyword options for add_argument).
    option_specs = [
        (('model_filepath',),
         {'help': 'Path of the trained (compact) model.'}),
        (('--wv',),
         {'default': '',
          'help': 'Path of the pre-trained Word2Vec model. (None if not needed)'}),
        (('--topn',),
         {'type': int,
          'default': 10,
          'help': 'Number of top-scored results displayed. (Default: 10)'}),
    ]
    for names, options in option_specs:
        parser.add_argument(*names, **options)

    return parser

# Parse the command line right away; `args` is read by the main block below.
argparser = get_argparser()
args = argparser.parse_args()

# Classifiers for which the main block loads a word-embedding model (--wv).
needembedded_classifiers = ['nnlibvec', 'sumvec']

# Topic models; the main block wraps these in a cosine-distance classifier.
topicmodels = ['ldatopic', 'lsitopic', 'rptopic', 'kerasautoencoder']

# Every classifier name this script accepts, in the original listing order.
allowed_classifiers = (
    topicmodels
    + ['topic_sklearn']
    + needembedded_classifiers
    + ['maxent']
)

# library loading
import os

import shorttext
from shorttext.utils.classification_exceptions import Word2VecModelNotExistException, AlgorithmNotExistException

# main block
if __name__ == '__main__':
    # Interactive console: load the trained model named on the command line,
    # then score each line of text typed by the user and print the `topn`
    # highest-scored labels.  An empty line (or end-of-input) ends the session.
    #
    # Raises IOError if the model file is missing, AlgorithmNotExistException
    # if the stored classifier name is not in `allowed_classifiers`, and
    # Word2VecModelNotExistException if an embedding model is required but the
    # --wv path does not exist.

    # check if the model file is given
    if not os.path.exists(args.model_filepath):
        raise IOError('Model file '+args.model_filepath+' not found!')

    # get the name of the classifier
    print('Retrieving classifier name...')
    classifier_name = shorttext.utils.compactmodel_io.get_model_classifier_name(args.model_filepath)
    if classifier_name not in allowed_classifiers:
        raise AlgorithmNotExistException(classifier_name)

    # load the Word2Vec model if necessary
    wvmodel = None
    if classifier_name in needembedded_classifiers:
        # check if the word-embedding model is available
        if not os.path.exists(args.wv):
            raise Word2VecModelNotExistException(args.wv)
        # if there, load it
        print('Loading word-embedding model... ' + args.wv)
        wvmodel = shorttext.utils.load_word2vec_model(args.wv)

    # load the classifier
    print('Initializing the classifier...')
    classifier = shorttext.smartload_compact_model(args.model_filepath, wvmodel)
    if classifier_name in topicmodels:
        # what was loaded is a topic model; wrap it in a cosine-distance classifier
        classifier = shorttext.classifiers.TopicVectorCosineDistanceClassifier(classifier)

    # Initializing the SpaCy kernel
    shorttext.utils.textpreprocessing.spaCyNLPHolder.getNLPInstance()

    # `raw_input` only exists on Python 2; Python 3 renamed it to `input`.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    # Console
    while True:
        try:
            # NOTE: the input is kept in `text`, not `shorttext`, so the
            # imported `shorttext` module is not overwritten.
            text = read_line('text> ')
        except EOFError:
            # Ctrl-D / exhausted piped input: finish the prompt line and quit.
            print('')
            break
        if not text:
            break

        scoredict = classifier.score(text)
        ranked = sorted(scoredict.items(), key=lambda s: s[1], reverse=True)
        for label, score in ranked[:args.topn]:
            print('%s  :  %s' % (label, score))

    print('Done.')