#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from collections import defaultdict as ddict
from Bio import SeqIO
from ete3 import Tree
from yaml import load as yamlload
from coretracker import __author__, __version__, date
from coretracker.classifier import Classifier, read_from_json
from coretracker.classifier import MODELPATH
from coretracker.classifier.models import ModelType
from coretracker.settings import Settings
from coretracker.coreutils import *
import time
import psutil
import argparse
import glob
import logging
import os

# Handle on the running interpreter process; memory_used() queries it to
# report the resident memory of this script.
process = psutil.Process(os.getpid())
# ENABLE_PAR tells whether a joblib implementation could be imported, in which
# case `Parallel` and `delayed` are available at module level.
# CPU_COUNT is the value used by `--parallel` when no explicit number is
# given; it stays at 0 when the CPU count cannot be determined.
ENABLE_PAR = True
CPU_COUNT = 0
try:
    from multiprocessing import cpu_count
    CPU_COUNT = cpu_count()
except (ImportError, NotImplementedError):
    # cpu_count() raises NotImplementedError on platforms where the number of
    # CPUs is unknown; that must not prevent the script from starting.
    pass

try:
    from joblib import Parallel, delayed
except ImportError:
    try:
        # old scikit-learn releases shipped a vendored copy of joblib
        from sklearn.externals.joblib import Parallel, delayed
    except Exception:
        # Best effort: whatever goes wrong while importing the vendored copy,
        # just run sequentially. `Exception` (and not a bare except) so that
        # KeyboardInterrupt/SystemExit are not swallowed.
        ENABLE_PAR = False

# Prefer yaml's CLoader when this PyYAML installation provides it, otherwise
# fall back to the regular Loader. Either way the name `Loader` is what the
# parameter file is parsed with.
try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# Feature labels handed to ModelType when the model is built.
# ('total_aa' is intentionally not part of the list at the moment.)
etiquette = [
    "fitch",
    "suspected",
    "Fisher pval",
    "Gene frac",
    "N. rea",
    "N. used",
    "Cod. count",
    "Sub. count",
    "G. len",
    "codon_lik",
    "N. mixte",
    "id",
]


def AlignmentProgramType(program):
    """argparse `type` callable validating the alignment program.

    The first whitespace-separated token of `program` must be either
    ``muscle`` or ``mafft``; anything after it (program options) is accepted
    as is.

    Returns the `program` string unchanged when it is valid.

    Raises argparse.ArgumentTypeError otherwise, including for an empty or
    blank value, so that argparse reports a regular usage error instead of
    an IndexError traceback.
    """
    tokens = program.split()
    if not tokens or tokens[0] not in ('muscle', 'mafft'):
        raise argparse.ArgumentTypeError(
            "Accepted alignment programs are mafft and muscle")
    return program


def memory_used():
    """Return the resident set size of the module-level `process`, in MB.

    The value is the `rss` field of the psutil memory information divided
    by 1024*1024.
    """
    try:
        meminfo = process.memory_info()
    except AttributeError:
        # older psutil releases only expose the accessor under its former
        # name; only a missing attribute triggers the fallback so that real
        # errors are not hidden
        meminfo = process.get_memory_info()
    return meminfo.rss / (1024.0 * 1024)


def _collect_hmmfiles(hmmdir):
    """Map each gene name to the file found for it directly inside `hmmdir`.

    The gene name is the file base name up to the first '.hmm'.
    """
    hmmfiles = {}
    try:
        for path in glob.glob(os.path.join(hmmdir, '*')):
            genename = os.path.basename(path).split('.hmm')[0]
            hmmfiles[genename] = path
    except Exception as err:
        # hmm files are optional: keep whatever was collected, but say so
        # instead of failing silently
        logging.warning("Could not list hmm files in %s: %s", hmmdir, err)
    return hmmfiles


def _read_codemap(path):
    """Read a `species code` file into a dict of species -> int genetic code.

    Blank lines and lines starting with '#' are skipped.
    """
    codemap = {}
    with open(path) as CMAP:
        for line in CMAP:
            line = line.strip()
            if line and not line.startswith('#'):
                spec, code = line.split()
                codemap[spec] = int(code)
    return codemap


def _read_restricted_species(restrict_to):
    """Return the species names listed in the file `restrict_to`.

    Blank lines and lines starting with '#' are skipped. When the file
    cannot be read, `restrict_to` itself is taken as a species name.
    """
    species = []
    try:
        with open(restrict_to) as SREST:
            for line in SREST:
                line = line.strip()
                if line and not line.startswith('#'):
                    species.append(line)
    except IOError:
        species.append(restrict_to)
    return species


def set_coretracker(args, settings):
    """Set all data for coretracker from the argument list

    `args` is the namespace produced by the command line parser and
    `settings` the Settings instance of the run; `settings` is updated in
    place (OUTDIR, SCALE, LIMIT_TO_SUSPECTED_SPECIES).

    Returns a tuple (reafinder, clf, model): the ReaGenomeFinder on which
    the possible reassignments were already computed, the loaded classifier
    and the ModelType built from `etiquette`.

    Raises ValueError when the classifier cannot be loaded or is not trained.
    """
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    if args.outdir:
        if not os.path.exists(args.outdir):
            # a failure to create the directory is left to propagate
            os.makedirs(args.outdir)
        settings.OUTDIR = args.outdir

    # a bare 'mafft' is run with its '--auto' option
    msaprg = args.align + ' --auto' if args.align == 'mafft' else args.align
    input_alignment = args.seq
    use_tree = None
    specietree = Tree(args.tree)
    if args.usetree:
        if args.align != 'mafft':
            logging.warning("Cannot usetree with muscle alignment")
        else:
            use_tree = args.tree
        # The option is stored under `usetree` (reading `args.use_tree`
        # raised AttributeError). A value typed on the command line is a
        # string whereas the option's const is the float 1.0, so normalise
        # to float.
        settings.SCALE = float(args.usetree)

    hmmfiles = _collect_hmmfiles(args.hmmdir) if args.hmmdir else {}
    codemap = _read_codemap(args.codemap) if args.codemap else {}
    spec_restricted_list = []
    if args.restrict_to:
        spec_restricted_list = _read_restricted_species(args.restrict_to)

    clf = Classifier.load_from_file(MODELPATH % settings.MODEL_TYPE)
    model = ModelType(settings.MODEL_TYPE, etiquette)

    if clf is None or not clf.trained:
        raise ValueError("Classifier not found or not trained!")

    t = time.time()
    seqloader = SequenceLoader(input_alignment, args.dnaseq, settings,
                               args.gapfilter, use_tree=use_tree,
                               refine_alignment=args.refine, msaprog=msaprg,
                               hmmdict=hmmfiles)
    t2 = time.time()
    logging.info("Sequence loaded in %.3f s, using %.3f MB",
                 abs(t2 - t), memory_used())

    # create sequence set
    setseq = SequenceSet(seqloader, specietree,
                         settings.GENETIC_CODE, codemap=codemap)
    # the id, gap and ic thresholds are reset to None by the caller when
    # --nofilter is given, in which case no filtering is wanted
    setseq.filter_sequences(args.idfilter, args.gapfilter,
                            args.iccontent, args.rmconst,
                            nofilter=args.nofilter)

    # only keep the requested species that are part of the common genomes
    spec_restricted_list = list(
        set(spec_restricted_list).intersection(setseq.common_genome))
    if spec_restricted_list:
        settings.LIMIT_TO_SUSPECTED_SPECIES = True
    t = time.time()
    logging.info("Sequence filtered in %.3f s, using %.3f MB",
                 abs(t - t2), memory_used())

    reafinder = ReaGenomeFinder(setseq, settings)
    reafinder.get_genomes()
    reafinder.possible_aa_reassignation(speclist=spec_restricted_list)
    t2 = time.time()
    logging.info("List of probable reassignment identified in %f s",
                 abs(t - t2))
    return reafinder, clf, model


def compile_result(x, clf, cod_align, model):
    """Classify one candidate reassignment and write its report.

    `x` is a (reafinder, fitch, data) triple. The data is formatted with
    `read_from_json` and `model.format_data`, scored with
    `clf.predict_proba`, then the report and the text table are produced
    through `utils`.

    Returns None when there is no data to classify, or when nothing is
    predicted positive and the SKIP_EMPTY setting is on; otherwise returns
    (rkp, [X_labels, pred, pred_prob, codvalid]).
    """
    reafinder, fitch, data = x

    complete_data = utils.makehash()
    complete_data['aa'][fitch.ori_aa1][fitch.dest_aa1] = data
    complete_data['genome'] = reafinder.reassignment_mapper['genome']
    X_data, X_labels, _ = read_from_json(
        complete_data, None, use_global=reafinder.settings.USE_GLOBAL)

    # nothing to classify
    if X_data is None or not X_data.size:
        return None

    # extract useful features
    X_data, X_dataprint, selected_et = model.format_data(X_data)
    pred_prob = clf.predict_proba(X_data)
    # the predicted class is the most probable one
    pred = pred_prob.argmax(axis=1)
    if sum(pred) == 0 and reafinder.settings.SKIP_EMPTY:
        return None

    sppval, outdir, rkp, codvalid = utils.get_report(
        fitch, reafinder, cod_align, (X_data, X_labels, pred_prob, pred))
    report_name = fitch.ori_aa + "_to_" + fitch.dest_aa + "_data.txt"
    utils.print_data_to_txt(os.path.join(outdir, report_name),
                            selected_et, X_dataprint, X_labels, pred,
                            pred_prob, sppval, fitch.dest_aa, valid=codvalid)

    return rkp, [X_labels, pred, pred_prob, codvalid]


def _build_parser():
    """Build the argparse parser describing the CoreTracker command line."""
    parser = argparse.ArgumentParser(
        description='CoreTracker, A codon reassignment tracker')

    parser.add_argument(
        '--wdir', '--outdir', dest="outdir", default="output", help="Working directory")

    parser.add_argument('--version', action='version',
                        version='CoreTracker v.%s' % __version__)

    parser.add_argument('--gapfilter', '--gap', type=float, default=0.6, dest='gapfilter',
                        help="Remove position with more than `gapfilter` gap from the alignment, using gapfilter as threshold (default :0.6)")

    parser.add_argument('--idfilter', '--id', type=float, default=0.5, dest='idfilter',
                        help="Conserve only position with at least `idfilter` residue identity (default : 0.5)")

    parser.add_argument('--icfilter', '--ic', type=float, default=0.5, dest='iccontent',
                        help="Shannon entropy threshold (default : 0.5 ). This will be used to discard column where IC < max(icvector)*icfilter)")

    parser.add_argument('-t', '--intree', dest="tree",
                        help='Input specietree in newick format', required=True)

    parser.add_argument('--protseq', '--prot', '-p', dest='seq',
                        help="Protein sequence input in core format", required=True)

    parser.add_argument('--dnaseq', '--dna', '-n', dest='dnaseq',
                        help="Nucleotides sequences input in core format", required=True)

    parser.add_argument('--debug', dest='debug', action='store_true',
                        help="Enable debug printing")

    parser.add_argument('--rmconst', dest='rmconst', action='store_true',
                        help="Remove constant site from filtered alignment. ")

    parser.add_argument('--norefine', dest='refine', action='store_false',
                        help="Do not refine the alignment. By default the alignment will be refined. This option should never be used if you have made your own multiple alignment and concatenate it. Else you will have absurd alignment (TO FIX)")

    parser.add_argument('--novalid', dest='valid', action='store_false',
                        help="Do not validate prediction by retranslating and checking alignment improvement (Faster)")

    parser.add_argument('--nofilter', '--nf', action="store_true", dest="nofilter",
                        help="Do not filter sequence alignment.")

    parser.add_argument('--align', dest='align', nargs='?', type=AlignmentProgramType, const="muscle",
                        help="Choose a program to align your sequences")

    parser.add_argument('--use_tree', dest='usetree', nargs='?', const=1.0,
                        help="This is helpfull only if the mafft alignment is selected. Perform multiple alignment, using species tree as guide tree. A scaling value is needed to compute the branch format for the '--treein' option of mafft. If you're unsure, use the default value. The tree must be rooted and binary. See http://mafft.cbrc.jp/alignment/software/treein.html")

    parser.add_argument('--expos', '--export_position', dest='expos', action="store_true",
                        help="Export a json file with the position of each reassignment in the corresponding genome.")

    parser.add_argument('--hmmdir', dest='hmmdir',
                        help="Link a directory with hmm files for alignment. Each hmmfile should be named in the following format : genename.hmm")

    parser.add_argument('--params', dest='params',
                        help="Use A parameter file to load parameters. If a parameter is not set, the default will be used")

    parser.add_argument('--parallel', dest='parallel', nargs='?', const=CPU_COUNT, type=int, default=0,
                        help="Use Parallelization during execution for each reassignment. This does not guarantee an increase in speed. CPU count will be used if no argument is provided")

    parser.add_argument('--imformat', dest='imformat', choices=('pdf', 'png', 'svg'), default="pdf",
                        help="Image format to use for output (Codon_data file)")

    parser.add_argument('--restrict_to', dest='restrict_to',
                        help="Restrict analysis to list of genomes only")

    parser.add_argument('--codemap', dest='codemap',
                        help="A tab delimited file that map each species to its genetic code. Default code (--gcode) will be used for remaining species")
    return parser


def main():
    """Run CoreTracker from the command line.

    Parses the arguments, fills the Settings, prepares the analysis with
    set_coretracker(), runs compile_result() on every candidate yielded by
    the analysis (through joblib when requested and available), optionally
    exports the reassignment positions, then saves the predictions.
    """
    parser = _build_parser()

    print("CoreTracker v:%s Copyright (C) %s %s" %
          (__version__, date, __author__))

    args = parser.parse_args()
    start_t = time.time()
    if args.nofilter:
        # set the filter parameters to none
        args.idfilter = None
        args.gapfilter = None
        args.iccontent = None

    setting = Settings()

    ad_params = {}
    if args.params:
        with open(args.params) as f:
            # NOTE(security): `Loader` is yaml's full loader, which can build
            # arbitrary Python objects; only use parameter files you trust.
            # An empty file is parsed as None, hence the `or {}`.
            ad_params = yamlload(f, Loader=Loader) or {}

    setting.fill(ad_params)
    setting.update_params(COMPUTE_POS=args.expos)
    setting.update_params(VALIDATION=args.valid)
    setting.update_params(IMAGE_FORMAT=args.imformat)

    reafinder, clf, model = set_coretracker(args, setting)
    codon_align, fcodon_align = reafinder.seqset.get_codon_alignment()
    cod_align = SeqIO.to_dict(fcodon_align)
    reafinder.set_rea_mapper()

    # The program is run at this phase
    if args.parallel > 0 and ENABLE_PAR:
        results = Parallel(n_jobs=args.parallel, verbose=1)(
            delayed(compile_result)(x, clf, cod_align, model)
            for x in reafinder.run_analysis(codon_align, fcodon_align))
    else:
        if args.parallel > 0:
            logging.warning(
                "Joblib requirement not found! Disabling parallelization")
        results = [compile_result(x, clf, cod_align, model)
                   for x in reafinder.run_analysis(codon_align, fcodon_align)]

    # remove None results then unzip
    results = [r for r in results if r is not None]
    if results:
        results, ALL_PRED = zip(*results)
    else:
        # zip(*[]) yields nothing, so unpacking it would raise ValueError
        logging.warning("No reassignment result to report")
        results, ALL_PRED = [], []

    if args.valid and args.expos and results:
        # merge the per-result position records, species by species
        rea_pos_keeper = ddict(dict)
        for r in results:
            for cuspec, readt in r.items():
                rea_pos_keeper[cuspec].update(readt)
        exp_outfile = os.path.join(reafinder.settings.OUTDIR, "positions.json")
        reafinder.export_position(rea_pos_keeper, exp_outfile)

    reafinder.save_all(ALL_PRED, True, nofilter=args.nofilter)
    logging.info("\n**END (%.3f s, %.3f MB)" %
                 (abs(time.time() - start_t),  memory_used()))


if __name__ == '__main__':
    main()
