#!python

"""Prep coding sequence alignments for ``phydms``.

Written by Jesse Bloom."""


import os
import re
import time
import subprocess
import matplotlib
matplotlib.use('pdf')
import matplotlib.pyplot as plt
plt.style.use('ggplot')
import Bio.SeqIO
import Bio.Data.CodonTable
import phydmslib.parsearguments
import phydmslib.file_io


# Characters that are replaced in sequence headers: any whitespace plus
# : ; ( ) [ ] , ' and ".  Written as a raw string so the regex escapes are
# not (mis)interpreted as Python string escapes.
_cleanheader = re.compile(r'[\s:;()\[\],\'"]')
def CleanHeader(head):
    """Returns copy of header with problem characters replaced by underscore.

    *head* is the header string. Every whitespace character and every one
    of the characters ``: ; ( ) [ ] , ' "`` is replaced by a single ``_``;
    all other characters are left unchanged.
    """
    return _cleanheader.sub('_', head)


def MAFFT_CDSandProtAlignments(headseqprots, mafftcmd, tempfileprefix):
    """Aligns proteins with MAFFT, then threads coding sequences onto them.

    *headseqprots* is a list of *(header, ntseq, protseq)* string tuples;
    the sequences must be ungapped with each *ntseq* exactly three times
    as long as its *protseq*.

    *mafftcmd* is the command used to invoke MAFFT. It is run with
    ``--auto --thread -1`` on a FASTA file of the proteins.

    *tempfileprefix* is the prefix for the temporary files (unaligned
    proteins, aligned proteins, MAFFT stderr output). These files are
    removed only if the function completes normally.

    Returns *alignedheadseqprots*, a list of *(header, alignedntseq,
    alignedprotseq)* tuples in the same order as *headseqprots*. Each gap
    in an aligned protein corresponds to ``---`` in the coding sequence.
    """
    nt_by_head = {}
    aa_by_head = {}
    for (name, ntseq, aaseq) in headseqprots:
        nt_by_head[name] = ntseq
        aa_by_head[name] = aaseq

    # temporary files get a leading underscore when placed in a directory
    (tmpdir, tmpbase) = os.path.split(tempfileprefix)
    tempfileprefix = (tmpdir + '/_' + tmpbase) if tmpdir else tmpbase
    unalignedprotsfile = tempfileprefix + '_unaligned_mafft_prots.fasta'
    alignedprotsfile = tempfileprefix + '_aligned_mafft_prots.fasta'
    outputfile = tempfileprefix + '_mafft_output.txt'

    with open(unalignedprotsfile, 'w') as handle:
        fasta_entries = ['>{0}\n{1}'.format(name, aaseq) for (name, ntseq, aaseq) in headseqprots]
        handle.write('\n'.join(fasta_entries))

    # MAFFT writes the alignment to stdout and its log to stderr
    mafft_args = [mafftcmd, '--auto', '--thread', '-1', unalignedprotsfile]
    with open(alignedprotsfile, 'w') as alignment_out, open(outputfile, 'w') as log_out:
        subprocess.check_call(mafft_args, stdout=alignment_out, stderr=log_out)

    aligned_by_head = {}
    for record in Bio.SeqIO.parse(alignedprotsfile, 'fasta'):
        name = record.description
        gappedprot = str(record.seq)
        aaseq = aa_by_head[name]
        ntseq = nt_by_head[name]
        assert ('-' not in aaseq) and ('-' not in ntseq), "Already gaps in seq or prot"
        assert len(gappedprot) >= len(aaseq)
        assert len(gappedprot.replace('-', '')) == len(aaseq)
        assert len(ntseq) == len(aaseq) * 3
        codons = []
        pos = 0 # index of next residue in the unaligned protein
        for residue in gappedprot:
            if residue != '-':
                assert residue == aaseq[pos]
                codons.append(ntseq[3 * pos : 3 * pos + 3])
                pos += 1
            else:
                codons.append('---')
        aligned_by_head[name] = (''.join(codons), gappedprot)

    # report results in the order the sequences were passed in
    alignedheadseqprots = [(name,) + aligned_by_head[name] for (name, ntseq, aaseq) in headseqprots]

    for path in (unalignedprotsfile, alignedprotsfile, outputfile):
        os.remove(path)

    return alignedheadseqprots


def main():
    """Main body of script.

    Steps, in order:

      1. Parse the command line and remove any existing output files
         (the alignment and the ``.pdf`` plot with the same base name).

      2. Read the input FASTA sequences, drop those matching
         ``--purgeseqs``, those with length not a multiple of 3, those
         with characters other than ``ACTG-``, and those that do not
         translate or that have premature stop codons.

      3. Locate the single reference sequence, collapse sequences
         encoding identical proteins (always retaining the reference
         sequence and any matching ``--keepseqs``).

      4. Align with MAFFT unless ``--prealigned``, then strip all sites
         gapped in the reference sequence.

      5. Purge sequences below ``--minidentity`` protein identity to the
         reference, then those with fewer than ``--minuniqueness``
         amino-acid differences from an already retained sequence.

      6. Write the final coding-sequence alignment and a scatter plot of
         DNA versus protein identity for retained and purged sequences.

    Raises *ValueError* for problems with the user's input: conflicting
    ``--purgeseqs`` / ``--keepseqs``, zero or multiple reference
    sequences, non-unique cleaned headers, pre-aligned sequences of
    unequal length, or an invalid MAFFT command.
    """

    # Parse command line arguments
    parser = phydmslib.parsearguments.PhyDMSPrepAlignmentParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # print some basic information
    print('\nBeginning execution of %s in directory %s at time %s\n' % (prog, os.getcwd(), time.asctime()))
    print("%s\n" % phydmslib.file_io.Versions())
    print('Parsed the following command-line arguments:\n%s\n' % '\n'.join(['\t%s = %s' % tup for tup in args.items()]))

    # the plot is written next to the alignment, with a .pdf extension
    plot = os.path.splitext(args['alignment'])[0] + '.pdf'
    for f in [plot, args['alignment']]:
        if os.path.isfile(f):
            print("Removing existing output file {0}".format(f))
            os.remove(f)

    # get keepseqs and purgeseqs as sets
    for argname in ['keepseqs', 'purgeseqs']:
        argvalue = args[argname]
        if argvalue:
            if len(argvalue) == 1 and argvalue[0] and os.path.isfile(argvalue[0]):
                # a single argument naming an existing file: one entry per non-blank line
                with open(argvalue[0]) as f:
                    args[argname] = set([line.strip() for line in f if line and not line.isspace()])
                print("Read {0} sequences from the --{1} file {2}".format(len(args[argname]), argname, argvalue[0]))
            else:
                args[argname] = set([x for x in argvalue if x])
                print("There are {0} sequences specified by --{1}".format(len(args[argname]), argname))
        else:
            args[argname] = set([])

    # read sequences
    seqs = []
    for seq in Bio.SeqIO.parse(args['inseqs'], 'fasta'):
        seq.seq = seq.seq.upper()
        seqs.append(seq)
    print("Read {0} sequences from {1}".format(len(seqs), args['inseqs']))

    # sequences to purge: names are matched as substrings of the raw or cleaned header
    cleanseqs = []
    for seq in seqs:
        head = seq.description
        cleanhead = CleanHeader(head)
        for purgename in args['purgeseqs']:
            if (purgename in head) or (purgename in cleanhead):
                for keepname in args['keepseqs']:
                    if (keepname in head) or (keepname in cleanhead):
                        raise ValueError("--purgeseqs and --keepseqs conflict about what to do with {0}".format(head))
                break
        else:
            # no purge name matched this header
            cleanseqs.append(seq)
    seqs = cleanseqs
    print("Retained {0} after removing those specified for purging by '--purgeseqs.'".format(len(seqs)))

    # remove gaps if sequences not pre-aligned
    # NOTE(review): Seq.ungap (and Seq.tomutable below) look like older
    # Biopython APIs -- confirm they exist in the Biopython version in use.
    if not args['prealigned']:
        for (i, seq) in enumerate(seqs):
            seqs[i].seq = seqs[i].seq.ungap(gap='-')

    # check length divisible by 3
    seqs = [seq for seq in seqs if len(seq) % 3 == 0]
    print("Retained {0} sequences after removing any with length not multiple of 3.".format(len(seqs)))

    # purge ambiguous nucleotide sequences
    seqmatch = re.compile('^[ACTG-]+$')
    seqs = [seq for seq in seqs if seqmatch.search(str(seq.seq))]
    print("Retained {0} sequences after purging any with ambiguous nucleotides.".format(len(seqs)))

    # make sure all sequences translate without premature stops
    prots = []
    remove_indices = []
    for i in range(len(seqs)):
        try:
            prots.append(str(seqs[i].seq.translate(gap='-', stop_symbol='*')))
        except Bio.Data.CodonTable.TranslationError:
            remove_indices.append(i) # doesn't translate
    # delete from the end so earlier indices stay valid
    for i in sorted(remove_indices, reverse=True):
        del seqs[i]
    assert len(prots) == len(seqs)
    remove_indices = []
    if args['prealigned']:
        for i in range(len(seqs)):
            prot = prots[i]
            if (prot.count('*') == 1) and prot.replace('-', '')[-1] == '*':
                # convert terminal stop to gap, trimmed later if all sites gapped
                istop = prot.index('*')
                prot = list(prots[i])
                prot[istop] = '-'
                prots[i] = ''.join(prot)
                # introduce gap requires making seq mutable
                iseq = seqs[i].seq.tomutable()
                iseq[3 * istop : 3 * istop + 3] = '---'
                seqs[i].seq = iseq.toseq()
            elif prot.count('*') >= 1:
                # premature stop
                remove_indices.append(i)
    else:
        for i in range(len(seqs)):
            prot = prots[i]
            if prot.count('*') == 1 and prot[-1] == '*':
                # trim terminal stop
                prots[i] = prots[i][ : -1]
                seqs[i].seq = seqs[i].seq[ : -3]
            elif prot.count('*') >= 1:
                # premature stop
                remove_indices.append(i)
    for i in sorted(remove_indices, reverse=True):
        del prots[i]
        del seqs[i]
    print("Retained {0} sequences after purging any with premature stops or that are otherwise un-translateable.".format(len(seqs)))

    # get reference sequence
    refseq = [seq for seq in seqs if (args['refseq'] in seq.description) or (CleanHeader(args['refseq']) in seq.description)]
    if len(refseq) == 1:
        refseq = refseq[0]
        print("Using the following as reference sequence: {0}".format(refseq.description))
    elif len(refseq) < 1:
        raise ValueError("Failed to find any sequence with a header that contained refseq identifier of {0}\nWas reference sequence purged due to not being translate-able?".format(args['refseq']))
    else:
        raise ValueError("Found multiple sequences with headers that contained refseq identifier of {0}".format(args['refseq']))

    # For removing sequences, we make two lists: sequences to always keeps,
    # then other sequences ordered by frequency that protein occurs.
    # This allows us to preferentially retain common sequences.
    seqs_alwayskeep = [] # entries are (head, seq, prot) as strings
    protcounts = {} # keyed by protein, values are lists of (head, seq, prot)
    foundrefseq = False
    assert len(seqs) == len(prots)
    for (seqrecord, prot) in zip(seqs, prots):
        head = seqrecord.description
        cleanhead = CleanHeader(head)
        seq = str(seqrecord.seq)
        if head == refseq.description:
            assert not foundrefseq, "Duplicate reference sequences"
            foundrefseq = True
            # reference sequence is always first in seqs_alwayskeep
            seqs_alwayskeep.insert(0, (cleanhead, seq, prot))
        else:
            for keepsubstring in args['keepseqs']:
                if (keepsubstring in head) or (keepsubstring in cleanhead):
                    seqs_alwayskeep.append((cleanhead, seq, prot))
                    break
            else:
                if prot in protcounts:
                    protcounts[prot].append((cleanhead, seq, prot))
                else:
                    protcounts[prot] = [(cleanhead, seq, prot)]
    assert len(seqs) == len(seqs_alwayskeep) + sum([len(x) for x in protcounts.values()]), "Somehow lost sequences when looking for unique seqs."
    assert foundrefseq, "Never found refseq to add to seqs_alwayskeep"
    # get one unique copy of each protein sorted by number of copies
    seqs_unique = sorted([(len(x), x[0]) for x in protcounts.values()], reverse=True)
    seqs_unique = [seqtup for (count, seqtup) in seqs_unique]
    print("Purged sequences encoding redundant proteins, being sure to retain reference sequence and any specified by '--keepseqs'. Overall, {0} sequences remain.".format(len(seqs_unique) + len(seqs_alwayskeep)))
    heads = [tup[0] for tup in seqs_unique] + [tup[0] for tup in seqs_alwayskeep]
    if len(heads) != len(set(heads)):
        # user-input problem, so raise rather than assert (asserts vanish under -O)
        duplicated = sorted(set([head for head in heads if heads.count(head) > 1]))
        raise ValueError("Found sequences with non-unique cleaned headers:\n{0}".format('\n'.join(duplicated)))

    if args['prealigned']:
        print("You specified '--prealigned', so the sequences are NOT being aligned.")
        # make sure sequences all the same length
        seqlengths = [len(tup[1]) for tup in seqs_unique + seqs_alwayskeep]
        if not all([seqlengths[0] == n for n in seqlengths]):
            raise ValueError("You indicated sequences are pre-aligned with the '--prealigned' option, but they are not all of the same length.")
    else:
        # align all sequences
        try:
            # universal_newlines=True gives text rather than bytes on Python 3
            mafftversion = subprocess.check_output([args['mafft'], '--version'], stderr=subprocess.STDOUT, universal_newlines=True).strip()
            print("Aligning sequences with MAFFT using {0} {1}".format(args['mafft'], mafftversion))
        except OSError:
            raise ValueError("The mafft command of {0} is not valid. Is mafft installed at this path?".format(args['mafft']))
        alignment = MAFFT_CDSandProtAlignments(seqs_unique + seqs_alwayskeep, args['mafft'], args['alignment'])
        # alignment preserves input order, so split it back at the same index
        nunique = len(seqs_unique)
        seqs_unique = alignment[ : nunique]
        seqs_alwayskeep = alignment[nunique : ]
        seqlengths = [len(tup[1]) for tup in seqs_unique + seqs_alwayskeep]
        assert all([seqlengths[0] == n for n in seqlengths]), "MAFFT alignment failed to yield sequences all of the same length."

    # strip gaps relative refseq
    assert seqs_alwayskeep[0][0] == CleanHeader(refseq.description), "Expected refseq to have this header:\t{0}\nFound this:\t{1}".format(refseq.description, seqs_alwayskeep[0][0])
    refseqprot = seqs_alwayskeep[0][2]
    gapped = set([i for i in range(len(refseqprot)) if refseqprot[i] == '-'])
    for xlist in [seqs_alwayskeep, seqs_unique]:
        for (i, (head, seq, prot)) in enumerate(xlist):
            strippedprot = ''.join([aa for (j, aa) in enumerate(prot) if j not in gapped])
            strippedseq = ''.join([seq[3 * j : 3 * j + 3] for j in range(len(prot)) if j not in gapped])
            xlist[i] = (head, strippedseq, strippedprot)
    refseqprot = seqs_alwayskeep[0][2] # with gaps removed
    assert '-' not in refseqprot, "Gap stripping from reference sequence failed."
    refseq = seqs_alwayskeep[0][1] # with gaps removed; now a string, not a record
    assert '-' not in refseq, "Gap stripping from reference sequence failed."
    print("After stripping gaps relative to reference sequence, all proteins are of length {0}".format(len(refseqprot)))

    # make seqs_unique, seq_alwayskeep be *(head, seq, prot, ntident, protident)*
    # lengths are floats so the identity fractions use true division
    refseqprotlength = float(len(refseqprot))
    refseqlength = float(len(refseq))
    for xlist in [seqs_alwayskeep, seqs_unique]:
        for (i, (head, seq, prot)) in enumerate(xlist):
            protident = sum([x == y for (x, y) in zip(prot, refseqprot)]) / refseqprotlength
            ntident = sum([x == y for (x, y) in zip(seq, refseq)]) / refseqlength
            xlist[i] = (head, seq, prot, ntident, protident)

    # purge sequences that don't meet identity threshold
    # (only seqs_unique is filtered; seqs_alwayskeep is never purged)
    purged_idents = [] # list of *(ntident, protident)* for purged sequences
    for xlist in [seqs_unique]:
        remove_indices = []
        for (i, (head, seq, prot, ntident, protident)) in enumerate(xlist):
            if protident < args['minidentity']:
                purged_idents.append((ntident, protident))
                remove_indices.append(i)
        for i in sorted(remove_indices, reverse=True):
            del xlist[i]
    print("Retained {0} sequences after purging those with < {1} protein identity to reference sequence.".format(len(seqs_alwayskeep + seqs_unique), args['minidentity']))

    # purge sequences that don't meet uniqueness threshold
    # alignment starts as the always-keep list and grows as sequences pass,
    # so each candidate is compared against everything retained so far
    alignment = seqs_alwayskeep
    for (head, seq, prot, ntident, protident) in seqs_unique:
        for (head2, seq2, prot2, ntident2, protident2) in alignment:
            protdiffs = sum([x != y for (x, y) in zip(prot, prot2)])
            if protdiffs < args['minuniqueness']:
                purged_idents.append((ntident, protident))
                break
        else:
            alignment.append((head, seq, prot, ntident, protident))
    print("Retained {0} sequences after purging those without at least {1} amino-acid differences with other retained sequences.".format(len(alignment), args['minuniqueness']))

    # write final alignment of coding sequences
    print("Writing the final alignment of {0} coding sequences to {1}".format(len(alignment), args['alignment']))
    with open(args['alignment'], 'w') as f:
        for (head, seq, prot, ntident, protident) in alignment:
            f.write('>{0}\n{1}\n'.format(head, seq))

    # make plot
    retained_idents = [(ntident, protident) for (head, seq, prot, ntident, protident) in alignment]
    print("Plotting retained and purged sequences to {0}".format(plot))
    fig = plt.figure()
    points = []
    labels = []
    for (label, idents, color, marker, alpha) in [
            ('purged', purged_idents, 'b', 'o', 0.25), 
            ('retained', retained_idents, 'r', 'd', 0.5)
            ]:
        pt = plt.scatter(
                [tup[0] for tup in idents],
                [tup[1] for tup in idents],
                c=color,
                marker=marker,
                s=30,
                alpha=alpha)
        points.append(pt)
        labels.append(label)
    plt.xlabel('DNA identity relative to refseq', fontsize=14)
    plt.ylabel('protein identity relative to refseq', fontsize=14)
    plt.legend(points, labels, scatterpoints=1, loc='upper left', fontsize=14)
    plt.savefig(plot)
    plt.close(fig) # release the figure now that it is saved

    # program done
    print('\nSuccessful completion of %s' % prog)


# Entry point: call main() only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main() # run the script
