#!/usr/bin/env python
# coding: utf-8

"""
   File Name: hdidx_index
      Author: Wan Ji
      E-mail: wanji@live.com
  Created on: Sat Jul 25 10:50:04 2015 CST
"""
# Program description handed to argparse (shown in the ``--help`` output).
# It currently contains only a newline.
DESCRIPTION = """
"""

import os
import argparse
import logging

import numpy as np

from hdidx import indexer
from hdidx.util import Reader, normalize, HDIdxException, DO_NORM


def runcmd(cmd):
    """ Run command.

    Logs the command line at INFO level, executes it through ``os.system``
    and logs a warning when the returned status is non-zero instead of
    silently discarding it.

    Args:
        cmd: command line string handed to ``os.system``.

    Returns:
        The value returned by ``os.system(cmd)`` (0 on success).
    """

    logging.info("%s", cmd)
    # NOTE: `cmd` is interpreted by the system shell; do not build it from
    # untrusted input.
    status = os.system(cmd)
    if status != 0:
        logging.warning("Command returned non-zero status %d: %s",
                        status, cmd)
    return status

def getargs():
    """ Parse program arguments.

    Positional arguments (in order): the image list, the feature directory,
    the indexer information path and the index data path.  ``--log`` and
    ``--dist`` are optional and default to ``"INFO"`` and ``"euclidean"``.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()`` with the
        attributes ``imglst``, ``featdir``, ``infopath``, ``idxpath``,
        ``log`` and ``dist``.
    """

    parser = argparse.ArgumentParser(
        description=DESCRIPTION,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('imglst', type=str,
                        help='image list')
    parser.add_argument('featdir', type=str,
                        help='feature directory')
    parser.add_argument('infopath', type=str,
                        help='indexer information')
    parser.add_argument('idxpath', type=str,
                        help='index data')
    # With nargs="?" a bare flag (no value) is stored as `const`, which
    # defaults to None.  Mirror the default in `const` so the option always
    # yields a string and later `.upper()` / `.lower()` calls cannot fail.
    parser.add_argument('--log', type=str, nargs="?", default="INFO",
                        const="INFO",
                        help='logging level, e.g. DEBUG, INFO, WARNING, ERROR')
    parser.add_argument('--dist', type=str, nargs="?", default="euclidean",
                        const="euclidean",
                        help='distance: euclidean, cosine')

    return parser.parse_args()


def _add_batch(idx, feat, keys, shift, do_norm):
    """ Add the rows of `feat` selected by `keys` to the indexer.

    Args:
        idx: indexer object exposing ``add(vectors, keys)``.
        feat: currently loaded feature batch (2-D, indexed by row).
        keys: ids to add; ``key - shift`` is the row of that id in `feat`.
        shift: number of feature rows contained in the batches loaded
            before `feat`.
        do_norm: when true, the selected rows are passed through
            ``normalize`` before being added.
    """

    rows = feat[[key - shift for key in keys], :]
    if do_norm:
        rows = normalize(rows)
    idx.add(rows, keys)


def main(args):
    """ Main entry.

    Loads the ids listed in ``args.imglst``, loads the indexer description
    from ``args.infopath``, attaches an LMDB storage at ``args.idxpath``
    (created with ``'clear': True``) and adds the listed feature vectors
    batch by batch.  Each id is treated as a row offset into the
    concatenation of the batches returned by ``Reader(args.featdir)``.

    Args:
        args: parsed arguments providing ``dist``, ``imglst``, ``infopath``,
            ``idxpath`` and ``featdir``.

    Raises:
        HDIdxException: if ``args.dist`` is not a key of ``DO_NORM``.
    """

    dist_type = args.dist.lower()
    if dist_type not in DO_NORM:
        raise HDIdxException("Unrecognized distance type: %s" % dist_type)
    do_norm = DO_NORM[dist_type]

    logging.info("Loading image list...")
    imglst = np.loadtxt(args.imglst, np.int32).reshape(-1)
    # Ids must be ascending so the batches can be consumed in a single
    # forward pass over the reader.
    imglst.sort()
    num_feats = imglst.shape[0]
    logging.info("\tDone!")

    idx = indexer.PQIndexer()
    logging.info("Loading indexer...")
    idx.load(args.infopath)
    logging.info("\tDone!")

    idx.set_storage('lmdb', {
        'path': args.idxpath,
        'clear': True,
    })

    logging.info("Loading first batch...")
    reader = Reader(args.featdir)
    feat = reader.get_next()
    logging.info("\tDone!")

    shift = 0   # rows contained in all batches preceding `feat`
    keys = []   # ids collected so far that live in the current batch

    logging.info("Indexing batches...")
    for iid in imglst:
        if iid - shift >= feat.shape[0]:
            # `iid` lies beyond the current batch: flush what was collected
            # for this batch, then advance to the batch containing `iid`.
            if keys:
                _add_batch(idx, feat, keys, shift, do_norm)
                keys = []
            # NOTE(review): assumes the reader yields enough batches to
            # cover every listed id -- confirm what get_next() returns once
            # the feature files are exhausted.
            while iid - shift >= feat.shape[0]:
                shift += feat.shape[0]
                feat = reader.get_next()

        keys.append(iid)

    # Flush the remaining ids.  This goes through the same helper as the
    # in-loop flush so the last batch is normalized as well when the
    # distance type requires it.
    if keys:
        _add_batch(idx, feat, keys, shift, do_norm)
    logging.info("\t%d/%d\tDone!", num_feats, num_feats)


if __name__ == '__main__':
    args = getargs()
    # `args.log` may be None when `--log` is given without a value
    # (nargs="?"); fall back to INFO instead of failing on `.upper()`.
    log_name = args.log or "INFO"
    # Map the level name onto the integer constant of the logging module.
    numeric_level = getattr(logging, log_name.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError("Invalid log level: %s" % log_name)
    logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s",
                        level=numeric_level)
    main(args)
