#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from collections import defaultdict as ddict
from coretracker import __author__, __version__, date
from coretracker.classifier import read_from_json
from coretracker.coreutils import *
from coretracker.settings import Settings
from coretracker.coreutils.letterconfig import aa_letters_3to1

from Bio import SeqIO
import argparse
import logging
import psutil
import time
import os
import yaml
import sys
# Handle on the current process; memory_used() reads its resident set size.
process = psutil.Process(os.getpid())

# ENABLE_PAR records whether joblib (Parallel, delayed, dump, load) could be
# imported. CPU_COUNT is the value used for a bare ``--parallel`` flag.
ENABLE_PAR = True
CPU_COUNT = 0
try:
    from multiprocessing import cpu_count
    CPU_COUNT = cpu_count()
except (ImportError, NotImplementedError):
    # The core count cannot be determined here: keep the 0 default, which
    # leaves parallel execution off unless a count is given explicitly.
    CPU_COUNT = 0

try:
    from joblib import Parallel, delayed, dump, load
except ImportError:
    try:
        # Fall back on the joblib copy bundled with older scikit-learn.
        from sklearn.externals.joblib import Parallel, delayed, dump, load
    except ImportError:
        # Only a missing package disables parallelism; any other failure
        # (or Ctrl-C) is left to propagate instead of being swallowed.
        ENABLE_PAR = False

# Feature labels handed to ``run_instance.get_model`` by the main script.
# 'total_aa' is currently left out of the list.
etiquette = [
    "fitch",
    "suspected",
    "Fisher pval",
    "Gene frac",
    "N. rea",
    "N. used",
    "Cod. count",
    "Sub. count",
    "G. len",
    "codon_lik",
    "N. mixte",
    "id",
]


def memory_used():
    """Return the resident set size of the current process, in MiB.

    Reads the module-level ``process`` handle. ``memory_info`` is tried
    first; ``get_memory_info`` is only used when that attribute is missing
    (NOTE(review): presumably the name used by older psutil releases --
    confirm against the psutil versions that must be supported).

    Returns:
        float: ``rss`` divided by 1024 * 1024.
    """
    try:
        info = process.memory_info()
    except AttributeError:
        # Only a missing method triggers the fallback; real psutil errors
        # are no longer hidden behind a bare ``except``.
        info = process.get_memory_info()
    return info.rss / (1024.0 * 1024)


def compile_result(x, clf, cod_align, model):
    """Classify one candidate reassignment and write its report.

    Args:
        x: a ``(reafinder, fitch, data)`` triple, as yielded by
            ``reafinder.run_analysis`` in the main script.
        clf: classifier exposing ``predict_proba``.
        cod_align: codon alignment passed through to ``utils.get_report``.
        model: object whose ``format_data`` selects the features to use.

    Returns:
        ``None`` when there is no feature data, or when every prediction is
        class 0 and ``SKIP_EMPTY`` is set. Otherwise the tuple
        ``(rkp, [X_labels, pred, pred_prob, codvalid])``.
    """
    finder, fitch, data = x

    # Wrap the data in the nested structure expected by read_from_json.
    complete = utils.makehash()
    complete['aa'][fitch.ori_aa1][fitch.dest_aa1] = data
    complete['genome'] = finder.reassignment_mapper['genome']

    features, labels, _ = read_from_json(
        complete, None, use_global=finder.settings.USE_GLOBAL)
    if features is None or not features.size:
        return None

    # Keep only the useful features, then take the most probable class.
    features, printable, selected_et = model.format_data(features)
    probabilities = clf.predict_proba(features)
    predicted = probabilities.argmax(axis=1)

    nothing_predicted = sum(predicted) == 0
    if nothing_predicted and finder.settings.SKIP_EMPTY:
        return None

    sppval, outdir, rkp, codvalid = utils.get_report(
        fitch, finder, cod_align,
        (features, labels, probabilities, predicted))

    report_name = fitch.ori_aa + "_to_" + fitch.dest_aa + "_data.txt"
    utils.print_data_to_txt(
        os.path.join(outdir, report_name), selected_et, printable, labels,
        predicted, probabilities, sppval, fitch.dest_aa, valid=codvalid)

    return rkp, [labels, predicted, probabilities, codvalid]


if __name__ == '__main__':
    # Command-line entry point: load a prepared run instance, classify each
    # candidate reassignment (optionally in parallel), then save the
    # predictions and, on request, the reassignment positions.

    # argument parser
    parser = argparse.ArgumentParser(
        description='CoreTracker, A codon reassignment tracker')

    parser.add_argument(
        '--wdir', '--outdir', dest="outdir", default="output", help="Working directory")

    parser.add_argument(
        'input', help="Input should be a runnable instance returned by coretracker-prep")

    parser.add_argument('--novalid', dest='valid', action='store_false',
                        help="Do not validate prediction by retranslating and checking alignment improvement")

    parser.add_argument('--expos', '--export_position', dest='expos', action="store_true",
                        help="Export a json file with the position of each reassignment in the corresponding genome.")

    parser.add_argument('--imformat', dest='imformat', choices=('pdf', 'png', 'svg'), default="pdf",
                        help="Image format to use for output (Codon_data file)")

    parser.add_argument('--aapair', dest='aapair',
                        help="Use a list of potential reassignments (see coretracker-prep's '.aa' output for example.")

    parser.add_argument('--parallel', dest='parallel', nargs='?', const=CPU_COUNT, type=int, default=0,
                        help="Use Parallelization during execution for each reassignment. This does not guarantee an increase in speed. CPU count will be used if no argument is provided")

    # NOTE(review): the version banner names 'coretracker-prep' although this
    # script consumes coretracker-prep output -- confirm the intended name.
    parser.add_argument('--version', action='version',
                        version='coretracker-prep v.%s' % __version__)
    parser.add_argument('--debug', dest='debug', action='store_true',
                        help="Enable debug printing")

    print("CoreTracker v:%s Copyright (C) %s %s" %
          (__version__, date, __author__))

    args = parser.parse_args()
    start_t = time.time()
    if args.debug:
        logging.basicConfig(level=logging.DEBUG)

    # Optional restriction of the analysis to a list of (origin, destination)
    # amino-acid pairs, converted from 3-letter to 1-letter codes.
    aasubset = None
    if args.aapair:
        try:
            with open(args.aapair) as f:
                # safe_load is enough for a plain mapping, and recent PyYAML
                # rejects yaml.load() when no Loader is given.
                aa_map = yaml.safe_load(f)
            aasubset = [(aa_letters_3to1[k], aa_letters_3to1[vv])
                        for k, v in aa_map.items() for vv in v]
        except Exception:
            logging.debug("Could not read the --aapair file", exc_info=True)
            print("Input provided with --aapair is not valid !")
            # Invalid input is an error: exit with a non-zero status.
            sys.exit(1)

    # NOTE(review): `load` is only defined when joblib could be imported at
    # the top of the file; without joblib this line raises NameError.
    run_instance = load(args.input)
    run_instance.rfinder.settings.update_params(COMPUTE_POS=args.expos)
    run_instance.rfinder.settings.update_params(VALIDATION=args.valid)
    run_instance.rfinder.settings.update_params(IMAGE_FORMAT=args.imformat)

    if args.outdir:
        if not os.path.exists(args.outdir):
            os.makedirs(args.outdir)
            # let original error handling
        run_instance.rfinder.settings.update_params(OUTDIR=args.outdir)

    codon_align, fcodon_align = run_instance.rfinder.seqset.get_codon_alignment()
    cod_align = SeqIO.to_dict(fcodon_align)
    clf, model = run_instance.get_model(etiquette)
    reafinder = run_instance.rfinder
    nofilter = run_instance.args.nofilter
    # Only the pieces extracted above are needed from here on.
    del run_instance

    results = []
    ALL_PRED = []
    if args.parallel > 0 and ENABLE_PAR:
        results = Parallel(n_jobs=args.parallel, verbose=1)(delayed(compile_result)(
            x, clf, cod_align, model) for x in reafinder.run_analysis(codon_align, fcodon_align, aasubset))
    else:
        if args.parallel > 0:
            logging.warning(
                "Joblib requirement not found! Disabling parallelization")
        results = [compile_result(x, clf, cod_align, model)
                   for x in reafinder.run_analysis(codon_align, fcodon_align, aasubset)]

    # remove None results then unzip
    results = [r for r in results if r is not None]
    if results:
        # zip(*[]) yields nothing to unpack, so only unzip a non-empty list;
        # otherwise both names keep their empty defaults.
        results, ALL_PRED = zip(*results)

    if args.valid and args.expos and results:
        # Merge the per-reassignment position maps into one map per key.
        rea_pos_keeper = ddict(dict)
        for r in results:
            for cuspec, readt in r.items():
                rea_pos_keeper[cuspec].update(readt)
        exp_outfile = os.path.join(reafinder.settings.OUTDIR, "positions.json")
        reafinder.export_position(rea_pos_keeper, exp_outfile)

    reafinder.save_all(ALL_PRED, True, nofilter=nofilter)
    logging.info("\n**END (%.3f s, %.3f MB)" %
                 (abs(time.time() - start_t),  memory_used()))
