#!/usr/bin/env python

# argument parsing
import argparse

def get_argparser():
    """Build the command-line parser for the interactive prediction console.

    Positional arguments are ``input_nameprefix``, ``algo`` and
    ``wvmodel_path``; optional ones are ``--ngram`` (int, default 2),
    ``--test`` (flag, default False) and ``--topn`` (int, default 10).

    :return: the configured argument parser
    :rtype: argparse.ArgumentParser
    """
    argparser = argparse.ArgumentParser(
        description='Perform prediction on short text with a given trained model.'
    )

    # positional arguments
    argparser.add_argument('input_nameprefix', help='Prefix of the path of input model.')
    argparser.add_argument(
        'algo',
        help='Algorithm architecture. (Options: sumword2vec (Summed Embedded Vectors), '
             'vnn (Neural Network on Embedded Vectors))'
    )
    argparser.add_argument('wvmodel_path', help='Path of the pre-trained Word2Vec model.')

    # optional arguments
    argparser.add_argument('--ngram', type=int, default=2,
                           help='n-gram, used in convolutional neural network only. (Default: 2)')
    argparser.add_argument('--test', action='store_true', default=False,
                           help='Checked if the test input contains the label, and metrics will be output. (Default: False)')
    argparser.add_argument('--topn', type=int, default=10,
                           help='Number of top-scored results displayed. (Default: 10)')
    return argparser

# Parse the command line right away. This runs at module level (outside the
# `__main__` guard) and before the shorttext imports below are executed, so
# `--help` and usage errors make argparse exit before those imports happen.
# `args` is the parsed namespace that the main block reads.
argparser = get_argparser()
args = argparser.parse_args()

# library loading
import os

import shorttext.classifiers.embed.nnlib.VarNNEmbedVecClassification as vnn
from shorttext.utils.wordembed import load_word2vec_model
import shorttext.classifiers.embed.sumvec.SumEmbedVecClassification as sumwv
from shorttext.classifiers import allowed_algos
from shorttext.utils.classification_exceptions import Word2VecModelNotExistException, AlgorithmNotExistException

# main block
if __name__ == '__main__':
    # Interactive console: load the word-embedding model and the trained
    # classifier, then score every line typed at the prompt until an empty
    # line or end-of-input is received.

    # raw_input only exists on Python 2; fall back to input on Python 3.
    try:
        read_line = raw_input
    except NameError:
        read_line = input

    # check validity
    if not os.path.exists(args.wvmodel_path):
        raise Word2VecModelNotExistException(args.wvmodel_path)
    if args.algo not in allowed_algos:
        raise AlgorithmNotExistException(args.algo)

    # load models
    print("Loading Word Embedding model...")
    wvmodel = load_word2vec_model(args.wvmodel_path)

    # initialize instance
    print("Instantiating classifier...")
    if args.algo == 'vnn':
        classifier = vnn.VarNNEmbeddedVecClassifier(wvmodel)
    elif args.algo == 'sumword2vec':
        classifier = sumwv.SumEmbeddedVecClassifier(wvmodel)
    else:
        # Listed in allowed_algos but not wired up in this script.
        raise AlgorithmNotExistException(args.algo)

    # load model
    print("Loading model...")
    classifier.loadmodel(args.input_nameprefix)

    # Console
    while True:
        try:
            text = read_line('text> ')
        except EOFError:
            # Ctrl-D / exhausted stdin: finish the prompt line and leave
            # cleanly instead of crashing with a traceback.
            print("")
            break
        if not text:
            # An empty line ends the session.
            break

        scoredict = classifier.score(text)
        # Highest score first, truncated to the requested number of results.
        ranked = sorted(scoredict.items(), key=lambda item: item[1], reverse=True)
        for label, score in ranked[:args.topn]:
            # Same layout as the former `print label, ' : ', score`.
            print("%s  :  %s" % (label, score))

    print("Done.")
