#!python

from panoptes_client import *

import pandas as pd
import ast
import numpy as np
import os, sys
import ast
import pdb
import datetime
import collections
import operator

from sqlalchemy.engine import create_engine
from gravityspy.api.project import GravitySpyProject

# Database connection; the user name and password are read from the
# environment so that no credentials live in this file.
engine = create_engine(
    'postgresql://{0}:{1}@gravityspy.ciera.northwestern.edu:5432/gravityspy'.format(
        os.environ['GRAVITYSPY_DATABASE_USER'],
        os.environ['GRAVITYSPY_DATABASE_PASSWD'],
    )
)

# Project description is loaded from a local cache file.
gspyproject = GravitySpyProject.load_project_from_cache('1104.pkl')

# Level structure of the project, restricted with the 'O2' ID filter.
workflowDictSubjectSets = gspyproject.get_level_structure(IDfilter='O2')

# The class names are the keys of the workflow '2117' entry, sorted so that
# every per-class array and column list below shares one ordering.
classes = sorted(workflowDictSubjectSets['2117'].keys())
numClasses = len(classes)

# Flat retirement criterion: the same 0.9 threshold for every class.
retired_thres = np.ones(numClasses) * .9

# Flat priors, because the category of an image is not known beforehand.
priors = np.ones(numClasses) / numClasses

# Classifications, without the rows whose integer choice is the -1 marker.
classifications = pd.read_sql('classificationsdev', engine)
has_choice = classifications.annotations_value_choiceINT != -1
classifications = classifications.loc[has_choice]

# Glitches, without the images whose status is 'Training'.
glitches = pd.read_sql("glitches", engine)
is_training = glitches.ImageStatus == 'Training'
glitches = glitches.loc[~is_training]

# Join the two tables, then keep only the first row seen for each
# (subject, user) pair.
# NOTE(review): the original comment says earlier classifications must come
# first, but nothing here sorts the rows -- confirm the table order.
combined_data = classifications.merge(glitches).drop_duplicates(
    subset=['links_subjects', 'links_user']
)

# Image table: one row per subject, indexed by subject ID, holding the
# per-class columns and the unique ID taken from the merged data.
columnsForImageDB = classes + ['uniqueID', 'links_subjects']
image_db = (
    combined_data[columnsForImageDB]
    .drop_duplicates(subset=['links_subjects'])
    .set_index(['links_subjects'])
)

# Bookkeeping columns, all starting from their "nothing seen yet" value.
for column, initial in (
    ('numLabel', 0),
    ('retired', 0),
    ('finalScore', 0.0),
    ('finalLabel', ''),
):
    image_db[column] = initial

# Confusion matrices as computed by the project object.
confusion_matrices = gspyproject.calculate_confusion_matrices()

# Empty frame; nothing in this file reads or fills it afterwards.
retired_db = pd.DataFrame()
# Add posterior contribution from user to every image/subject
def get_post_contribution(x):
    """Add one user's posterior contribution to every image they labelled.

    Meant to be applied row-wise to ``confusion_matrices``. The function
    works entirely through side effects on the module-level ``image_db``:
    the per-class columns and ``numLabel`` are updated for each image the
    user labelled, and images whose posterior exceeds ``retired_thres`` are
    marked retired with their ``finalScore`` and ``finalLabel`` filled in.

    Parameters
    ----------
    x : row-like object
        Must expose ``conf_matrix`` (square array indexed by the integer
        annotation value) and ``userID``.

    Returns
    -------
    None
    """
    conf_divided = x.conf_matrix
    # A confusion matrix with no non-zero entry carries no information.
    if not conf_divided.any():
        return

    # For every image a user classifies as a certain type, their
    # contribution to the posterior is the same, so it is computed once
    # per user here rather than once per image.
    # NOTE(review): with a plain ndarray the (n,) row sums broadcast across
    # columns; with an np.matrix they stay (n, 1) and each row is divided by
    # its own sum. Confirm which type calculate_confusion_matrices returns.
    post_contribution = conf_divided/np.sum(conf_divided, axis=1)

    # Row indices (annotation values) that hold at least one non-NaN entry.
    annotations = np.unique(np.argwhere(~np.isnan(post_contribution))[:, 0])

    # All images for which this user gave one of those annotations.
    labelled_by_user = (
        (combined_data.links_user == x.userID)
        & (combined_data.annotations_value_choiceINT.isin(annotations))
    )
    imagesUserLabeled = combined_data.loc[
        labelled_by_user, ['links_subjects', 'annotations_value_choiceINT']
    ]
    # Nothing to add when the user has no usable classification.
    if imagesUserLabeled.empty:
        return
    userLabeledSubjectIDs = imagesUserLabeled.links_subjects.values

    # One row of the contribution matrix per labelled image, selected by
    # the label the user gave. `.values` replaces `Series.as_matrix()`,
    # which no longer exists in current pandas.
    rows = imagesUserLabeled.annotations_value_choiceINT.values
    posteriorToAdd = post_contribution[rows, :]

    # Accumulate the contribution and count the extra label.
    image_db.loc[userLabeledSubjectIDs, classes] = (
        image_db.loc[userLabeledSubjectIDs, classes].add(posteriorToAdd, axis=1)
    )
    image_db.loc[userLabeledSubjectIDs, 'numLabel'] = (
        image_db.loc[userLabeledSubjectIDs, 'numLabel'] + 1
    )

    # Only images with at least two labels that are not retired yet are
    # candidates for retirement.
    labelled = image_db.loc[userLabeledSubjectIDs]
    tmpDB = labelled[(labelled['numLabel'] >= 2) & (labelled['retired'] == 0)]
    if tmpDB.empty:
        return

    # Average the accumulated scores. The extra +1 in the denominator
    # presumably accounts for the initial per-class score already stored
    # in image_db -- TODO confirm.
    posterior = tmpDB[classes].divide(tmpDB['numLabel'].values + 1, axis=0)
    above_threshold = (posterior > retired_thres).any(axis=1)
    if not above_threshold.any():
        return

    retiredIDs = tmpDB.loc[above_threshold].index
    retired_posterior = posterior.loc[retiredIDs][classes]
    image_db.loc[retiredIDs, 'finalScore'] = retired_posterior.max(axis=1).values
    image_db.loc[retiredIDs, 'finalLabel'] = retired_posterior.idxmax(axis=1).values
    image_db.loc[retiredIDs, 'retired'] = 1


# Run the posterior update once per row of the confusion-matrix table; the
# function updates image_db in place, so the result of apply is not kept.
confusion_matrices.apply(get_post_contribution, axis=1)

# Store the retired images only, replacing any existing 'retiredimages'
# table.
is_retired = image_db.retired == 1
retired_images = image_db.loc[
    is_retired, ['finalScore', 'finalLabel', 'retired', 'uniqueID']
]
retired_images.to_sql('retiredimages', engine, if_exists='replace')
