#!python

import pandas
from gravityspy.ml import train_semantic_index
import os
import argparse

def parse_commandline(argv=None):
    """Parse and validate the arguments given on the command-line.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse
        reads them from ``sys.argv[1:]``, which is what the script relies
        on when it calls this function without arguments.

    Returns
    -------
    argparse.Namespace
        Namespace with the attributes ``path_to_trainingset``,
        ``trainingset_pickle_file``, ``model_name``, ``batch_size``,
        ``training_steps_per_epoch``, ``num_epoch`` and ``verbose``.
        ``training_steps_per_epoch`` and ``num_epoch`` are ``None`` when
        they are not supplied.

    Notes
    -----
    Validation failures are reported through ``parser.error``, which prints
    the usage message to stderr and terminates with ``SystemExit`` (exit
    status 2). Validation fails when

    * ``--path-to-trainingset`` is given but is not an existing directory, or
    * ``--path-to-trainingset`` is not given and the pickle file does not
      exist.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--path-to-trainingset",
                        help="folder where labeled images live", default=None)
    parser.add_argument("--trainingset-pickle-file",
                        help="file where the entire pickled training set "
                             "will live. This pickle file should be read in "
                             "by pandas",
                        default=os.path.join('pickeleddata',
                                             'rgb_trainingset.pkl'))
    parser.add_argument("--model-name",
                        help="Save model filename",
                        default='similarity_model.h5')
    # "--batch-size" matches the hyphenated spelling of every other option;
    # the historical "--batch_size" spelling is kept as an alias so existing
    # invocations keep working.
    parser.add_argument("--batch-size", "--batch_size", dest="batch_size",
                        type=int,
                        help="Number of pairs per draw",
                        default=50)
    parser.add_argument("--training-steps-per-epoch", type=int,
                        help="Number of draws to do per epoch")
    parser.add_argument("--num-epoch", type=int,
                        help="Number of epochs")
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="Run in Verbose Mode")
    args = parser.parse_args(argv)

    # parser.error() never returns (it exits via SystemExit), so it is
    # called directly rather than being used as the operand of ``raise``.
    if args.path_to_trainingset:
        if not os.path.isdir(args.path_to_trainingset):
            parser.error('Training Set path does not exist.')
    elif not os.path.isfile(args.trainingset_pickle_file):
        parser.error('If you are not providing a path to the '
                     'trainingset you must specify '
                     'a pickle file containing the already '
                     'pixelized training set')

    return args

args = parse_commandline()

# Obtain the training set: load the existing pickle with pandas when no image
# folder was supplied, otherwise build it from the labeled images via
# pickle_trainingset (which is handed the pickle path as its save address).
if not args.path_to_trainingset:
    data = pandas.read_pickle(args.trainingset_pickle_file)
else:
    data = train_semantic_index.pickle_trainingset(
        path_to_trainingset=args.path_to_trainingset,
        save_address=args.trainingset_pickle_file,
        verbose=args.verbose,
    )

# make_model returns two models; only the semantic index model is saved below.
semantic_idx_model, similarity_model = train_semantic_index.make_model(
    data,
    nb_epoch=args.num_epoch,
    order_of_channels="channels_last",
    batch_size=args.batch_size,
    training_steps_per_epoch=args.training_steps_per_epoch,
    validation_steps_per_epoch=55,
)

semantic_idx_model.save(args.model_name)
