#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
from coretracker.coreutils import SequenceLoader, CoreFile
import coretracker.coreutils.utils as utils
import sys
import warnings
import os
import glob
import logging

# Silence every warning category for the whole run so that library warnings
# do not clutter the command-line output.
warnings.filterwarnings("ignore")


# Command-line interface of the translator.
# NOTE: --noclean and --nostopstrip are declared with action='store_false':
# their destinations default to True and only become False when the flag is
# given on the command line.
parser = argparse.ArgumentParser(
    description='Translator, translate nucleotide alignment. You need to install hmmalign, muscle or mafft and biopython')

# Input / output locations
parser.add_argument('--outdir', '--wdir', dest="outdir",
                    default="trans_output", help="Working directory")
parser.add_argument('--nuc', '--input', '-n', dest='dnafile',
                    required=True, help="Nucleotide input file")
parser.add_argument('--prot', '-p', dest='protfile',
                    help="protein input file. Force the use of this version of the protein for each gene")

# Genetic code selection
parser.add_argument('--gcode', type=int, default=1, dest='gcode',
                    help="Genetic code to use for translation. Default value is 1")
parser.add_argument('--codemap', dest='codemap',
                    help="A tab delimited file that map each species to its genetic code. The value of --gcode will be used for missing species")

# Alignment options (the help of --prog used to be a copy of the --gcode help)
parser.add_argument('--prog', default="mafft", dest='prog', choices=[
                    'mafft', 'muscle'], help="Alignment program to use. Default value is mafft")
parser.add_argument('--align', dest='align', action='store_true',
                    help="Whether we should align or not")
parser.add_argument('--refine', dest='refine', action='store_true',
                    help="Whether we should refine the alignment with hmm")
parser.add_argument('--noclean', dest='noclean', action='store_false',
                    help="Whether we should clean temporary files created or not")
parser.add_argument('--nostopstrip', dest='nostopstrip', action='store_false',
                    help="Do not strip stop codons from sequences")
parser.add_argument('--alignarg', dest='alignargs', default="",
                    help="Additional arguments for the alignment program (muscle or mafft).")

# Translation / filtering switches
parser.add_argument('--notrans', dest='notrans', action='store_true',
                    help="Whether we should attempt to translate or not. This overrule --align")
parser.add_argument('--filter', dest='filter', action='store_true',
                    help="Filter nuc sequences to remove gene where frame-shifting occur")

# HMM refinement options
parser.add_argument('--hmmdir', dest='hmmdir',
                    help="Link a directory with hmm files for alignment. Each hmmfile should be named in the following format : genename.hmm")
parser.add_argument('--count', '-c', dest="hcount", default=1, type=int,
                    help="Number of loops for the hmm build and align")

parser.add_argument('--verbose', '-v', dest='verbose', action='store_true',
                    help="Print verbose (debug purpose)")


# Read the command line once; everything below is driven by `args`.
args = parser.parse_args()

# --verbose switches the root logger to DEBUG.
if args.verbose:
    logging.basicConfig(level=logging.DEBUG)

# Hand a missing output directory to the project helper
# (NOTE(review): purge_directory is expected to create it -- confirm in
# coretracker.coreutils.utils).
outdir_missing = not os.path.exists(args.outdir)
if outdir_missing:
    utils.purge_directory(args.outdir)

# Shorthands reused by the rest of the script. `core_inst` starts as the
# path of the nucleotide file and is replaced by the loaded data later on.
align = args.align
gcode = args.gcode
core_inst = args.dnafile

# Optional per-species genetic code table: one "species code" pair per line,
# blank lines and lines starting with '#' are skipped.
codemap = {}
if args.codemap:
    with open(args.codemap) as cmap_handle:
        for raw_line in cmap_handle:
            entry = raw_line.strip()
            if not entry or entry.startswith('#'):
                continue
            species, code_id = entry.split()
            codemap[species] = int(code_id)

# Command line of the aligner; mafft falls back to --auto when no extra
# argument is supplied.
if args.prog == 'mafft':
    prog = "mafft " + (args.alignargs or "--auto")
else:
    prog = 'muscle ' + args.alignargs

# Map each gene name to its hmm profile. Every entry of --hmmdir is
# registered under the part of its file name that precedes '.hmm'
# (e.g. "cox1.hmm" -> hmmfiles["cox1"]).
hmmfiles = {}
if args.hmmdir:
    try:
        for f in glob.glob(os.path.join(args.hmmdir, '*')):
            genename = os.path.basename(f).split('.hmm')[0]
            hmmfiles[genename] = f
    except Exception as e:
        # Best effort: a problem while listing the profiles is reported but
        # does not abort the run. print() with a single argument behaves the
        # same under Python 2 and Python 3, unlike the `print e` statement.
        print(e)

if args.filter:
    # Keep only the genes in which every sequence, once its gaps are
    # removed, has a length that is a multiple of 3.
    tmp_core_inst = CoreFile(core_inst, alphabet='nuc')
    # A frame-shifted sequence is not treated as a missing sequence: the
    # entire gene is removed instead.
    core_inst = {}
    gcount, total_gcount = 0, 0
    for gene, seqs in tmp_core_inst.items():
        good_data = []   # ungapped records of this gene that are in frame
        disp_data = []   # (species, ungapped length) of frame-shifted records
        species_seen = set()
        for seq in seqs:
            if seq.name in species_seen:
                raise ValueError(
                    "Species %s seen at least 2 times in gene %s. Please Correct" % (seq.name, gene))
            species_seen.add(seq.name)
            # Remove the gaps once and reuse the result for both the frame
            # test and the stored sequence.
            ungapped = seq.seq.ungap('-')
            if len(ungapped) % 3 != 0:
                # Report the length that was actually tested (without gaps);
                # the gapped length could itself be a multiple of 3.
                disp_data.append((seq.name, len(ungapped)))
            else:
                seq.seq = ungapped
                good_data.append(seq)
        total_gcount += 1
        if not disp_data:
            core_inst[gene] = good_data
            gcount += 1
            print("++> OK (%s)" % gene)
        else:
            print("--> Fail (%s), %d specs have frame-shifting : \n\t(%s)" %
                  (gene, len(disp_data), ", ".join(["%s: %d" % x for x in disp_data])))

    if not core_inst:
        sys.exit("After removing frameshifting, there isn't any remaining genes")

    print("======== Total of %d / %d genes rescued ========" %
          (gcount, total_gcount))

else:
    # No filtering requested: load the nucleotide file as is.
    core_inst = CoreFile(core_inst, alphabet='nuc')

# Translate the nucleotide sequences, optionally override them with the
# user supplied proteins, then align and/or write the result.
if not args.notrans:
    translated_prot = SequenceLoader.translate(core_inst, gcode, codemap)

    # --nostopstrip is a store_false flag: args.nostopstrip is True unless
    # the flag was given, i.e. it is True when stop codons SHOULD be
    # stripped. (The previous `not args.nostopstrip` test inverted the flag.)
    strip_stops = args.nostopstrip
    if strip_stops:
        translated_prot, core_inst = SequenceLoader.clean_stop(
            translated_prot, core_inst)

    # if a protfile is given, use this version of the protein
    if args.protfile:
        prot_core_inst = CoreFile(args.protfile, alphabet='prot')
        for gene, seqs in prot_core_inst.items():
            try:
                gene_prots = translated_prot[gene]
                gene_nucs = core_inst[gene]
            except KeyError:
                # The protein file mentions a gene that is absent from the
                # nucleotide data (e.g. removed by --filter): nothing to
                # override for it.
                continue
            for recseq in seqs:
                trans_seq = [x for x in gene_prots if x.id == recseq.id]
                nuc_seq = [x for x in gene_nucs if x.id == recseq.id]
                if strip_stops and recseq.seq.endswith('*'):
                    recseq.seq = recseq.seq[0:-1]
                # Only accept the supplied protein when it is consistent
                # with the nucleotide sequence (3 nucleotides per residue).
                if (trans_seq and nuc_seq and
                        len(nuc_seq[0].seq) == len(recseq.seq) * 3):
                    trans_seq[0].seq = recseq.seq

    alignment = {}
    if align:
        for (gene, seqs) in translated_prot.items():
            # Genes with a single sequence are not aligned and therefore
            # not written to the aligned output.
            if len(seqs) > 1:
                al = SequenceLoader._align(seqs, prog, None, 1.0, args.outdir)
                if args.refine:
                    # args.noclean is True unless --noclean was given, so it
                    # can be passed directly as the `clean` switch.
                    al = SequenceLoader._refine(al, 9999, args.outdir, loop=args.hcount,
                                                clean=args.noclean, hmmfile=hmmfiles.get(gene, None))
                alignment[gene] = al

        CoreFile.write_corefile(alignment, os.path.join(
            args.outdir, "prot_aligned.core"))
    else:
        CoreFile.write_corefile(
            translated_prot, os.path.join(args.outdir, "prot.core"))

# When filtering was requested, write the retained nucleotide sequences to
# "nuc.core" in the output directory so they can be given to coretracker.
if args.filter:
    nuc_corefile = os.path.join(args.outdir, "nuc.core")
    CoreFile.write_corefile(core_inst, nuc_corefile)
