#!/usr/bin/env python

import os
from joblib import load
from sys import argv

from paras.domain_extraction.extract_residues import extract_stach_codes
import paras.data
import paras.data.train_and_test_data.structure_alignments
from paras.feature_extraction.get_sequence_features import get_features, to_feature_vectors
from paras.domain_extraction.extract_adomains import find_adomains

# Alignment FASTA handed to extract_stach_codes in run_paras; resolved relative
# to the installed structure_alignments package directory.
ALIGNMENT_FILE = os.path.join(
    os.path.dirname(paras.data.train_and_test_data.structure_alignments.__file__),
    'training_alignment.fasta',
)
# Serialized classifier that run_paras loads with joblib; resolved relative to
# the installed paras.data package directory.
CLASSIFIER_FILE = os.path.join(
    os.path.dirname(paras.data.__file__),
    'paras.classifier',
)


def get_top_x_aa(amino_acid_classes, probabilities, x):
    """Return the ``x`` best-scoring ``(probability, class)`` pairs.

    Each probability is paired with the class label found at the same index
    in ``amino_acid_classes``. The pairs are ordered from highest to lowest;
    because whole tuples are compared, equal probabilities are ordered by
    class label, also descending.

    Args:
        amino_acid_classes: Indexable collection of class labels.
        probabilities: Iterable of scores, positionally matching the labels.
        x: Number of leading pairs to keep (applied as a slice bound).

    Returns:
        A list of ``(probability, class)`` tuples of length at most ``x``.
    """
    scored = [
        (score, amino_acid_classes[position])
        for position, score in enumerate(probabilities)
    ]
    ranked = sorted(scored, reverse=True)
    return ranked[:x]


def run_paras(sequence_file, out_file, get_probabilities=False, x=1):
    """Run the prediction pipeline on a sequence file and write a TSV report.

    The sequences are passed through ``find_adomains``,
    ``extract_stach_codes``, ``get_features`` and ``to_feature_vectors``;
    the resulting feature vectors are scored with the classifier stored at
    ``CLASSIFIER_FILE``.

    Args:
        sequence_file: Path of the input sequence file given to
            ``find_adomains``.
        out_file: Path of the tab-separated file to write (overwritten).
        get_probabilities: When false, write one ``prediction`` column per
            id. When true, write the top-ranked classes with their
            probabilities instead.
        x: Number of ``prediction_N``/``probability_N`` column pairs to write
            when ``get_probabilities`` is true. It is clamped to the range
            ``0..number of classes`` so that the header always has exactly as
            many columns as each data row.
    """
    # find_adomains receives the current directory as its second argument.
    out_adomains = '.'
    adomain_sequences = find_adomains(sequence_file, out_adomains)

    # Only the second mapping returned by extract_stach_codes feeds the features.
    _, id_to_34 = extract_stach_codes(adomain_sequences, ALIGNMENT_FILE)
    id_to_features = get_features(id_to_34)
    ids, feature_vectors = to_feature_vectors(id_to_features)
    classifier = load(CLASSIFIER_FILE)

    if not get_probabilities:
        predictions = classifier.predict(feature_vectors)
        with open(out_file, 'w') as out:
            out.write('id\tprediction\n')
            for seq_id, prediction in zip(ids, predictions):
                out.write(f'{seq_id}\t{prediction}\n')
        return

    probabilities = classifier.predict_proba(feature_vectors)
    amino_acid_classes = classifier.classes_

    # A row can never hold more pairs than there are classes, and a negative
    # count would slice from the end; clamp so header and rows stay aligned.
    n_top = max(0, min(x, len(amino_acid_classes)))

    with open(out_file, 'w') as out:
        header_columns = ''.join(
            f'\tprediction_{rank}\tprobability_{rank}'
            for rank in range(1, n_top + 1)
        )
        out.write(f'id{header_columns}\n')

        for seq_id, probability_list in zip(ids, probabilities):
            top_hits = get_top_x_aa(amino_acid_classes, probability_list, n_top)
            row_columns = ''.join(f'\t{aa}\t{prob}' for prob, aa in top_hits)
            out.write(f'{seq_id}{row_columns}\n')


if __name__ == "__main__":
    # Command line: <fasta_file> <out_file> <number_of_predictions>
    if len(argv) < 4:
        # Exit with a usage message instead of an IndexError traceback.
        raise SystemExit(
            f'Usage: {argv[0]} <fasta_file> <out_file> <number_of_predictions>'
        )

    fasta_file = argv[1]
    out_file = argv[2]

    try:
        top_x = int(argv[3])
    except ValueError:
        raise SystemExit(
            f'number_of_predictions must be an integer, got {argv[3]!r}'
        ) from None

    run_paras(fasta_file, out_file, get_probabilities=True, x=top_x)
