#!/usr/bin/env python
# coding: utf-8

"""
   File Name: hdidx_build
      Author: Wan Ji
      E-mail: wanji@live.com
  Created on: Sat Jul 25 10:26:19 2015 CST
"""
DESCRIPTION = """
"""

import os
import argparse
import logging

import numpy as np

from hdidx import indexer
from hdidx.util import Reader, normalize, HDIdxException, DO_NORM


def runcmd(cmd):
    """ Run a shell command and report its exit status.

    The command line is logged at INFO level and then handed to
    ``os.system``, so it is interpreted by the system shell.  Callers
    must therefore not pass text built from untrusted input.

    Args:
        cmd: command line string to execute.

    Returns:
        The value returned by ``os.system(cmd)``; 0 conventionally means
        the command succeeded.  A non-zero value is also logged as a
        warning so that failures are no longer silently dropped.
    """

    logging.info("%s", cmd)
    status = os.system(cmd)
    if status != 0:
        logging.warning("Command returned non-zero status %s: %s",
                        status, cmd)
    return status


def getargs():
    """ Parse program arguments.

    Positional arguments are ``imglst``, ``featdir``, ``infopath`` and the
    optional ``nsubq`` (int, defaults to 16).  The options are ``--log``
    (defaults to "INFO") and ``--dist`` (defaults to "euclidean").

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """

    parser = argparse.ArgumentParser(
        description=DESCRIPTION,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('imglst', type=str,
                        help='id of images for training')
    parser.add_argument('featdir', type=str,
                        help='feature directory')
    parser.add_argument('infopath', type=str,
                        help='indexer information')
    parser.add_argument('nsubq', type=int, nargs='?', default=16,
                        help='number of sub quantizer')
    # nargs="?" allows the flag to be given without a value; without `const`
    # argparse would store None in that case, which breaks the later
    # `.upper()` / `.lower()` calls.  Fall back to the default instead.
    parser.add_argument('--log', type=str, nargs="?", default="INFO",
                        const="INFO",
                        help='logging level, e.g. DEBUG, INFO, WARNING, ERROR')
    parser.add_argument('--dist', type=str, nargs="?", default="euclidean",
                        const="euclidean",
                        help='distance: euclidean, cosine')

    return parser.parse_args()


def main(args):
    """ Main entry: gather the listed features and build a PQ index.

    Steps performed:
      1. Check ``args.dist`` (lower-cased) against ``DO_NORM`` and look up
         whether the features have to be normalized.
      2. Load the ids from ``args.imglst`` and sort them in ascending order.
      3. Walk through the batches returned by ``Reader(args.featdir)`` and
         copy the row of every listed id into one array.
      4. Optionally normalize, build a ``PQIndexer`` with ``args.nsubq``
         sub quantizers and save it to ``args.infopath``.

    Args:
        args: parsed arguments as returned by ``getargs``; the attributes
            ``dist``, ``imglst``, ``featdir``, ``nsubq`` and ``infopath``
            are read.

    Raises:
        HDIdxException: if ``args.dist`` is not a key of ``DO_NORM``.
    """

    dist_type = args.dist.lower()
    if dist_type not in DO_NORM:
        raise HDIdxException("Unrecognized distance type: %s" % dist_type)
    do_norm = DO_NORM[dist_type]

    logging.info("Loading image list...")
    # reshape(-1) flattens the loaded ids, so a file holding a single id
    # (loaded as a 0-d array) is handled as well.
    imglst = np.loadtxt(args.imglst, np.int32).reshape(-1)
    # Ascending order lets the batches be consumed in one forward pass.
    imglst.sort()
    num_feats = imglst.shape[0]
    logging.info("\tDone!")

    logging.info("Loading first batch...")
    reader = Reader(args.featdir)
    feat = reader.get_next()
    logging.info("\tDone!")

    logging.info("Allocating memory for feat...")
    # The first batch supplies the feature dimension and dtype.
    allfeat = np.zeros((num_feats, feat.shape[1]), feat.dtype)
    logging.info("\tDone!")

    # No bulk pre-copy of the first batch: every row of `allfeat` is written
    # by the loop below, and a bulk copy fails with a shape mismatch when
    # the first batch holds more rows than there are listed ids.

    # Number of rows contained in the batches that were already consumed.
    shift = 0
    logging.info("Loading features...")
    for i, img_id in enumerate(imglst):
        if i % 10000 == 0:
            logging.info("\t%d/%d", i, num_feats)
        # Ids are global row offsets over all batches; `row` is the offset
        # inside the current batch.
        row = img_id - shift
        while row >= feat.shape[0]:
            # The id lies beyond the current batch: move on to the next one.
            # NOTE(review): what get_next() returns once the batches are
            # exhausted is not visible here -- confirm that ids past the
            # last stored row are reported sensibly.
            shift += feat.shape[0]
            feat = reader.get_next()
            row = img_id - shift

        allfeat[i, :] = feat[row, :]
    logging.info("\t%d/%dDone!", num_feats, num_feats)

    if do_norm:
        allfeat = normalize(allfeat)

    pq_indexer = indexer.PQIndexer()
    logging.info("Building Index...\n")
    pq_indexer.build({
        'vals': allfeat,
        'nsubq': args.nsubq,
    })
    logging.info("\tDone!")

    pq_indexer.save(args.infopath)


if __name__ == '__main__':
    # Parse the command line first so that the requested verbosity is known
    # before logging gets configured.
    args = getargs()

    # Translate the level name given via --log into the attribute of the
    # same (upper-cased) name on the logging module; reject anything that
    # does not resolve to an integer.
    log_level = getattr(logging, args.log.upper(), None)
    if not isinstance(log_level, int):
        raise ValueError("Invalid log level: " + args.log)

    logging.basicConfig(level=log_level,
                        format="%(asctime)s - %(levelname)s - %(message)s")
    main(args)
