#!/usr/bin/env python

import os
from joblib import load
from sys import argv

from paras.domain_extraction.extract_residues import extract_stach_codes
import paras.data
import paras.data.train_and_test_data.structure_alignments
from paras.feature_extraction.get_sequence_features import get_features, to_feature_vectors
from paras.domain_extraction.extract_adomains import find_adomains
from paras.common.parsers import parse_specificities
from paras.common.fasta import read_fasta

# Path to 'specificities.txt', resolved relative to the directory that holds
# the paras.data package.
SPECIFICITIES_FILE = os.path.join(os.path.dirname(paras.data.__file__), 'specificities.txt')
# Parsed once at import time; run_paras indexes this object by sequence id to
# obtain the specificity label for that sequence.
SPECIFICITIES = parse_specificities(SPECIFICITIES_FILE)


def run_paras(sequence_file, alignment_file, classifier_file, strict=True):
    """Score a stored classifier on the domains extracted from a sequence file.

    The function runs ``find_adomains`` on ``sequence_file`` (passing ``'.'``
    as its second argument), feeds the result to ``extract_stach_codes``
    together with ``alignment_file``, turns the second mapping it returns into
    feature vectors, and calls ``score`` on the classifier loaded from
    ``classifier_file`` using the labels looked up in ``SPECIFICITIES``.

    Labels are looked up with the part of each sequence id that precedes the
    first ``'|'``.

    Args:
        sequence_file: Path handed to ``find_adomains``.
        alignment_file: Path handed to ``extract_stach_codes`` and, when
            ``strict`` is false, also read with ``read_fasta``.
        classifier_file: Path of the classifier loaded with ``joblib.load``.
        strict: When false, every sequence whose label does not occur among
            the labels of the ids read from ``alignment_file`` is removed
            before scoring. When true, no sequences are removed.

    Returns:
        None. The number of sequences before and after the optional filtering
        step and the classifier score are printed to standard output.
    """
    adomain_sequences = find_adomains(sequence_file, '.')

    # extract_stach_codes returns two mappings; only the second is used here.
    _, id_to_34 = extract_stach_codes(adomain_sequences, alignment_file)
    print(len(id_to_34))

    if not strict:
        # Labels of every id present in the alignment file.
        known_specs = {
            SPECIFICITIES[train_id] for train_id in read_fasta(alignment_file)
        }
        # Collect first, delete afterwards: the mapping must not change size
        # while it is being iterated.
        unseen_ids = [
            seq_id for seq_id in id_to_34
            if SPECIFICITIES[seq_id.split('|')[0]] not in known_specs
        ]
        for seq_id in unseen_ids:
            del id_to_34[seq_id]

    print(len(id_to_34))

    ids, feature_vectors = to_feature_vectors(get_features(id_to_34))
    # Labels in the same order as the feature vectors.
    true_vals = [SPECIFICITIES[seq_id.split('|')[0]] for seq_id in ids]

    print(load(classifier_file).score(feature_vectors, true_vals))


if __name__ == "__main__":
    # Command line: <sequence_file> <alignment_file> <classifier_file> <strict>
    # where <strict> is an integer; 0 is treated as False, anything else True.
    if len(argv) < 5:
        # Fail with a readable usage message instead of an IndexError traceback.
        raise SystemExit(
            'usage: {} <sequence_file> <alignment_file> <classifier_file> '
            '<strict: 0|1>'.format(argv[0])
        )

    sequence_file = argv[1]
    alignment_file = argv[2]
    classifier_file = argv[3]
    try:
        strict = bool(int(argv[4]))
    except ValueError:
        # Tell the user which argument was malformed.
        raise SystemExit(
            'strict must be an integer (0 or 1), got: {!r}'.format(argv[4])
        )
    print(strict)

    run_paras(sequence_file, alignment_file, classifier_file, strict=strict)