#!python

import argparse
import os

from matplotlib import use
use('agg')

import gravityspy.ml.train_classifier as train_classifier
import pandas
import h5py

# Define command-line arguments here

def parse_commandline(argv=None):
    """Parse the options given on the command-line.

    Parameters
    ----------
    argv : list of str, optional
        Argument strings to parse. When ``None`` (the default) argparse
        falls back to ``sys.argv[1:]``, so calling the function with no
        arguments behaves as before.

    Returns
    -------
    argparse.Namespace
        The parsed and validated options.

    Raises
    ------
    SystemExit
        Raised by argparse when the arguments are malformed, when neither
        a training-set directory nor an existing pickle file is available,
        or when the given training-set directory does not exist.
    """
    parser = argparse.ArgumentParser(description=
       "An examples commandline of how to obtain a model is given below: "
       "THEANO_FLAGS=mode=FAST_RUN,device=cuda,floatX=float32 trainmodel "
       "--path-to-trainingset='somedir' --number-of-classes='somenum'")
    parser.add_argument("--path-to-trainingset",
                        help="folder where labeled images live", default=None)
    parser.add_argument("--number-of-classes", type=int,
                        help="How many classes do you have", required=True)
    parser.add_argument("--trainingset-pickle-file",
                        help="folder where the entire pickled training set "
                             "will live. This pickle file should be read in "
                             "by pandas",
                        default=os.path.join('pickeleddata',
                                             'trainingset.pkl'))
    parser.add_argument("--image-order", nargs='+',
                        default=['0.5.png', '1.0.png', '2.0.png', '4.0.png'])
    parser.add_argument("--model-name",
                        help="what you would like to model filename to be",
                        default='multi_view_classifier.h5')
    parser.add_argument("--order-of-channels",
                        help="Are you running with Theano and "
                             "Tensorflow backend? If tensorflow "
                             "than you want channels_last and if "
                             "theano channels_first",
                        default='channels_last')
    parser.add_argument("--batch-size", type=int, default=500,
                        help="defines the batch size, 30 is a reasonable size")
    parser.add_argument("--nb-epoch", type=int, default=20,
                        help="defines the number of iterations, "
                        "130 is reasonable. You can set it to 100 or below, "
                        "if you have time concern for training.")
    parser.add_argument("--fraction-validation", type=float, default=0.125,
                        help="Perentage of trianing set to save for validation")
    parser.add_argument("--fraction-testing", type=float, default=0,
                        help="Percentage of training set to save for testing")
    parser.add_argument("--randomseed", type=int, default=1986,
                        help="Set random seed")
    parser.add_argument("--verbose", action="store_true", default=False,
                        help="Run in Verbose Mode")
    args = parser.parse_args(argv)

    # parser.error() prints the usage message and exits the process itself
    # (SystemExit), so it must simply be called, not raised.
    if (not args.path_to_trainingset) and (
            not os.path.isfile(args.trainingset_pickle_file)):
        # Without an image directory the only possible data source is an
        # already existing pickle file.
        parser.error('If you are not providing a path to the '
                     'trainingset you must specify '
                     'a pickle file containing the already '
                     'pixelized training set')

    if args.path_to_trainingset and (
            not os.path.isdir(args.path_to_trainingset)):
        parser.error('Training Set path does not exist.')

    return args

# Parse commandline
args = parse_commandline()

# Pixelate and pickle the training set images when an image directory was
# given; otherwise load the previously pickled training set.
if args.path_to_trainingset:
    data = train_classifier.pickle_trainingset(
        path_to_trainingset=args.path_to_trainingset,
        save_address=args.trainingset_pickle_file,
        verbose=args.verbose
        )
else:
    # NOTE: unpickling executes whatever the file contains; only point
    # --trainingset-pickle-file at files from a trusted source.
    data = pandas.read_pickle(args.trainingset_pickle_file)

# A testing fraction of 0 means "no testing split": pass None instead
fraction_testing = args.fraction_testing or None

# Collect the sorted, unique class labels as ASCII byte strings so they can
# be stored alongside the model below
class_names = sorted(data.true_label.unique())
class_names = [n.encode("ascii", "ignore") for n in class_names]

# Train model
# NOTE(review): verbose is hard-coded to True here although --verbose is
# honoured when pickling the training set above -- confirm this is intended.
model = train_classifier.make_model(
    data=data,
    order_of_channels=args.order_of_channels,
    image_order=args.image_order,
    batch_size=args.batch_size,
    nb_epoch=args.nb_epoch,
    nb_classes=args.number_of_classes,
    fraction_validation=args.fraction_validation,
    fraction_testing=fraction_testing,
    best_model_based_validset=0,
    image_size=[140, 170],
    random_seed=args.randomseed,
    verbose=True
    )

model.save(args.model_name)

# Re-open the saved model file in read/write mode and add the class names
# as a 'labels/labels' dataset. The context manager guarantees the file is
# closed even if creating the group or dataset raises.
with h5py.File(args.model_name, 'r+') as f:
    grp = f.create_group('labels')
    grp.create_dataset('labels', (len(class_names), 1), 'S100', class_names)
