#!/usr/bin/env python

# argument parsing
import argparse

def getargparser():
    """Build the command-line parser for this script.

    The parser accepts a single positional argument, ``word2vec_modelpath``.

    :return: the configured ``argparse.ArgumentParser``
    """
    parser_options = {
        'description': 'Find the similarity between two short sentences using Word2Vec.',
    }
    argparser = argparse.ArgumentParser(**parser_options)
    argparser.add_argument(
        'word2vec_modelpath',
        help='Path of the Word2Vec model',
    )
    return argparser

# Parse the command line as soon as the module is loaded, i.e. before the
# third-party imports that follow, so `--help` and usage errors make argparse
# exit before those imports run.
# NOTE(review): this is not guarded by `__name__ == '__main__'`, so it also
# runs (and reads sys.argv) when the file is imported as a module.
parser = getargparser()
# `args.word2vec_modelpath` is read by the `__main__` block at the end of the file.
args = parser.parse_args()

import shorttext
from shorttext.utils import tokenize
from scipy.spatial import distance
from itertools import product
import numpy as np

class ShortSentenceWord2VecSimilarity:
    """Score how similar two short sentences are using Word2Vec embeddings.

    The score is a "soft" Jaccard index: tokens of the two sentences are
    paired greedily by cosine similarity, and the summed similarity of the
    chosen pairs plays the role of the intersection size.
    """

    def __init__(self, modelpath):
        """Load the Word2Vec model.

        :param modelpath: path passed to ``shorttext.utils.load_word2vec_model``.
            When ``None``, no model is loaded and ``self.model`` is ``None``.
        """
        self.model = shorttext.utils.load_word2vec_model(modelpath) if modelpath is not None else None

    def sim_words(self, word1, word2):
        """Return the cosine similarity between the vectors of two words.

        :param word1: first word; looked up as ``self.model[word1]``
        :param word2: second word; looked up as ``self.model[word2]``
        :return: ``1 - cosine distance`` of the two vectors
        """
        return 1 - distance.cosine(self.model[word1], self.model[word2])

    def jaccardscore_sents(self, sent1, sent2):
        """Return the Word2Vec-based Jaccard score of two sentences.

        Both sentences are tokenized and tokens absent from the model are
        dropped. Every cross-sentence token pair is scored with
        :meth:`sim_words`; pairs are then visited from the most to the least
        similar, and a pair is accepted only if neither of its tokens has
        already been used. The accepted similarities add up to the
        "intersection"; the "union" is the total token count minus that sum.

        :param sent1: first sentence (string)
        :param sent2: second sentence (string)
        :return: ``intersection / union`` when the union is positive;
            ``1.`` when the union is not positive and the intersection is
            zero (e.g. neither sentence has a token known to the model);
            ``numpy.inf`` otherwise.
        """
        # Build real lists: the tokens are measured with len() and indexed
        # below, which a lazy filter() object (Python 3) does not support.
        tokens1 = [w for w in tokenize(sent1) if w in self.model]
        tokens2 = [w for w in tokenize(sent2) if w in self.model]

        # allowableN[k] is True while token k of sentence N is still unmatched.
        allowable1 = [True] * len(tokens1)
        allowable2 = [True] * len(tokens2)

        # Similarity of every (token of sent1, token of sent2) pair, keyed by
        # the pair of token positions.
        simdict = {
            (i, j): self.sim_words(word1, word2)
            for (i, word1), (j, word2) in product(enumerate(tokens1), enumerate(tokens2))
        }

        # Greedy one-to-one matching, best pairs first.
        intersection = 0.0
        simdictitems = sorted(simdict.items(), key=lambda s: s[1], reverse=True)
        for (i, j), sim in simdictitems:
            if allowable1[i] and allowable2[j]:
                intersection += sim
                allowable1[i] = False
                allowable2[j] = False

        union = len(tokens1) + len(tokens2) - intersection

        if union > 0:
            return intersection / union
        elif intersection == 0:
            # Nothing to compare on either side: treat the sentences as identical.
            return 1.
        else:
            # Defensive fallback for a non-positive union with a non-zero
            # intersection.
            return np.inf

if __name__ == '__main__':
    # Interactive session: repeatedly read two sentences and print their
    # Word2Vec Jaccard score. An empty first sentence ends the session.

    # `raw_input` only exists on Python 2; Python 3 renamed it to `input`.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    calculator = ShortSentenceWord2VecSimilarity(args.word2vec_modelpath)
    # preload tokenizer
    tokenize('Mogu is cute.')
    while True:
        sent1 = read_line('sent1> ')
        if not sent1:
            break
        sent2 = read_line('sent2> ')
        score = calculator.jaccardscore_sents(sent1, sent2)
        # A single pre-formatted argument prints the same text whether `print`
        # is the Python 2 statement or the Python 3 function.
        print("%s %s" % ("Word2Vec Jaccard Score Similarity = ", score))