#!python

# ---- Import the project, third-party and standard-library modules used below.

from gravityspy.api.project import GravitySpyProject
import gravityspy.ml.read_image as read_image
import gravityspy.ml.labelling_test_glitches as label_glitches
from gravityspy.utils import log
import argparse
import pandas as pd
import os
from sqlalchemy.engine import create_engine

def parse_commandline():
    """Build this script's argument parser and parse the command line.

    Four required options, ``--Filename1`` through ``--Filename4``, each
    accept one or more values (``nargs='+'``).  ``--path-to-cnn-model`` is
    required and takes a single value.

    Returns
    -------
    argparse.Namespace
        The parsed arguments, with attributes ``Filename1`` .. ``Filename4``
        (lists of strings) and ``path_to_cnn_model`` (string).
    """
    # (option flag, help text) for the four multi-valued file options.
    image_options = (
        ("--Filename1", "Path To File1"),
        ("--Filename2", "Path To File2"),
        ("--Filename3", "Path To File 3"),
        ("--Filename4", "Path to File 4"),
    )

    cli = argparse.ArgumentParser(description=__doc__)
    for flag, help_text in image_options:
        cli.add_argument(flag, help=help_text, required=True, nargs='+')
    cli.add_argument("--path-to-cnn-model",
                     required=True,
                     help="Path to name of cnn model")

    return cli.parse_args()

# Read the command-line options once; the rest of the script is driven by them.
args = parse_commandline()

# Logger used for all progress messages emitted below.
logger = log.Logger('Gravity Spy: Relabel Images')

# Load the cached Gravity Spy project object from '1104.pkl'.
gspyproject = GravitySpyProject.load_project_from_cache('1104.pkl')

# The four --FilenameN options are parallel lists.  Zipping them yields one
# tuple per position, holding the entry at that position from each option, so
# that the images in a tuple can be converted and labelled together.
list_of_images_all = zip(args.Filename1,
                         args.Filename2,
                         args.Filename3,
                         args.Filename4)

# ---------------------------------------------------------------------------
# Label every image set with the CNN model and write the scores back to the
# ``glitches`` table.
#
# Everything that does not depend on the individual image set (the column
# names, the UPDATE statement and the database engine) is prepared once,
# before the loop, instead of being rebuilt on every iteration.
# ---------------------------------------------------------------------------

# The sorted keys of workflow '2117' name the score columns; the two Gravity
# Spy specific columns are appended after them.
# NOTE(review): assumes get_level_structure() returns the same structure on
# every call, so one lookup can serve all image sets -- confirm.
workflowDictSubjectSets = gspyproject.get_level_structure(IDfilter='O2')
classes = sorted(workflowDictSubjectSets['2117'].keys())
classes.extend(["uniqueID", "Label"])

# Local import: only this section of the script builds SQL statements.
from sqlalchemy import text

# Values are passed as bound parameters rather than interpolated into the SQL
# text, so a quote character in an ID or label cannot break or alter the
# statement.  Column names cannot be bound, so they are still formatted in as
# quoted identifiers.
assignments = ', '.join(
    '"{0}" = :value_{1}'.format(column, position)
    for position, column in enumerate(classes))
update_statement = text(
    'UPDATE glitches SET {0} WHERE "uniqueID" = :target_id'.format(assignments))

# One engine for the whole run.  Raises KeyError immediately if either
# credential environment variable is missing.
engine = create_engine(
    'postgresql://{0}:{1}@gravityspy.ciera.northwestern.edu:5432/gravityspy'
    .format(os.environ['GRAVITYSPY_DATABASE_USER'],
            os.environ['GRAVITYSPY_DATABASE_PASSWD']))

try:
    for list_of_images in list_of_images_all:
        # The ID is the second '_'-separated token of the first image's
        # base name.
        ID = list_of_images[0].split('/')[-1].split('_')[1]

        logger.info('Converting image to ML readable...')

        # One single-row column per image, keyed by the image path.
        image_dataDF = pd.DataFrame()
        for image in list_of_images:
            logger.info('Converting {0}'.format(image))
            image_data = read_image.read_grayscale(image,
                                                   resolution=0.3)
            image_dataDF[image] = [image_data]

        image_dataDF['uniqueID'] = ID

        # Now label the image
        logger.info('Labelling image...')
        scores, MLlabel = label_glitches.label_glitches(
            image_data=image_dataDF,
            model_adr='{0}'.format(args.path_to_cnn_model),
            image_size=[140, 170],
            verbose=False)

        # Score of the winning class, and that class's name.
        confidence = float(scores[0][MLlabel])
        Label = classes[MLlabel]
        logger.info('This image has received the following label: {0} with {1} percent confidence'.format(Label, confidence))

        # Values in the same order as ``classes``: the per-class scores, then
        # the uniqueID and the label.
        row = scores[0].tolist()
        row.append(ID)
        row.append(Label)
        if len(row) != len(classes):
            # Fail loudly instead of silently writing misaligned columns.
            raise ValueError(
                'Expected {0} values for the glitches columns but got {1}'
                .format(len(classes), len(row)))

        parameters = {'value_{0}'.format(position): value
                      for position, value in enumerate(row)}
        # Identifies which entry of the table to update.
        parameters['target_id'] = ID

        engine.execute(update_statement, parameters)
finally:
    # Release pooled connections even if an image set failed part-way.
    engine.dispose()
