#!/usr/bin/env python

import argparse
import os

from paras.domain_extraction.extract_residues import extract_stach_codes, stach_to_challis
import paras.data
import paras.data.train_and_test_data.structure_alignments
from paras.common.writers import write_tabular
from paras.domain_extraction.extract_adomains import find_adomains

# Path to 'training_alignment.fasta', resolved relative to the directory of the
# paras.data.train_and_test_data.structure_alignments package so it does not
# depend on the current working directory.
ALIGNMENT_FILE = os.path.join(
    os.path.dirname(paras.data.train_and_test_data.structure_alignments.__file__),
    'training_alignment.fasta',
)


def make_parser():
    """
    Build the command-line parser for this script.

    Registers two required string options (-i/--input, -o/--output) and three
    optional boolean switches (-c/--challis, -s/--stachelhaus,
    -a/--active_site) that default to False.

    :return: the configured argparse.ArgumentParser.
    """
    cli = argparse.ArgumentParser(description="Extract active site residues from adenylation domains.")

    # Required path-like options, registered in the order they appear in --help.
    required_options = (
        ('-i', '--input', ".fasta file containing proteins to be analysed."),
        ('-o', '--output', ".txt file which will hold the output"),
    )
    for short_flag, long_flag, help_text in required_options:
        cli.add_argument(short_flag, long_flag, required=True, type=str, help=help_text)

    # Boolean switches selecting which residue codes end up in the output.
    switches = (
        ('-c', '--challis', 'Include challis code extraction.'),
        ('-s', '--stachelhaus', 'Include stachelhaus code extraction.'),
        ('-a', '--active_site', 'Include whole active site extraction.'),
    )
    for short_flag, long_flag, help_text in switches:
        cli.add_argument(short_flag, long_flag, action='store_true', help=help_text)

    return cli


def get_residue_codes(args):
    """
    Extract the requested residue codes from an input fasta and write them as a table.

    Runs find_adomains on args.input (handing it an 'a_domains' directory in
    the current working directory), passes the result together with
    ALIGNMENT_FILE to extract_stach_codes, and writes the selected mappings to
    args.output through write_tabular.

    :param args: parsed command-line namespace. Reads args.input, args.output
        and the boolean flags args.stachelhaus, args.active_site, args.challis.
    :return: None. The result is written to args.output.
    """
    out_domains = os.path.join(os.getcwd(), 'a_domains')
    # exist_ok=True replaces the exists()/mkdir() pair: it is free of the
    # check-then-create race and does not fail when the directory is already there.
    os.makedirs(out_domains, exist_ok=True)
    adomain_sequences = find_adomains(args.input, out_domains)

    # Two mappings come back; presumably keyed by domain id, with id_to_34
    # holding the full active-site signature -- confirm against extract_stach_codes.
    id_to_stachelhaus, id_to_34 = extract_stach_codes(adomain_sequences, ALIGNMENT_FILE)

    dictionaries = []
    labels = []

    # Column order in the output is fixed: Stachelhaus, Active Site, Challis.
    if args.stachelhaus:
        dictionaries.append(id_to_stachelhaus)
        labels.append('Stachelhaus')

    if args.active_site:
        dictionaries.append(id_to_34)
        labels.append('Active Site')

    if args.challis:
        # Derived from the Stachelhaus mapping; only computed when requested.
        dictionaries.append(stach_to_challis(id_to_stachelhaus))
        labels.append("Challis")

    # NOTE(review): when no flag is given both lists are empty; how
    # write_tabular handles that is not visible here.
    write_tabular(dictionaries, labels, args.output)


if __name__ == "__main__":
    # Script entry point: parse the command line and run the extraction.
    get_residue_codes(make_parser().parse_args())

