#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from collections import defaultdict as ddict
from coretracker import __author__, __version__, date
from coretracker.classifier import read_from_json
from coretracker.coreutils import *
from coretracker.settings import Settings
from coretracker.coreutils.letterconfig import aa_letters_3to1
from Bio import SeqIO
from functools import partial
import argparse
import logging
import psutil
import time
import os
import yaml
import resource
import gc
import sys
import multiprocessing as mp
# psutil handle on the current process; memory_used() reads its RSS from it.
process = psutil.Process(os.getpid())

# CPU count reported by multiprocessing; used as the worker count when
# --parallel is given without a value.
CPU_COUNT = mp.cpu_count()

# NOTE(review): `global` statements at module level have no effect in Python.
# Both names are assigned later in the `__main__` section of this script.
global spec_data
global cod_align

# Prefer the standalone joblib package and fall back to the copy exposed
# through sklearn.externals when joblib cannot be imported.
try:
    from joblib import dump, load
except ImportError:
    from sklearn.externals.joblib import dump, load

# Feature labels handed to r_instance.get_model() in the `__main__` section.
etiquette = ["fitch", "suspected", "Fisher pval", "Gene frac",
             "N. rea", "N. used", "Cod. count", "Sub. count",
             "G. len", "codon_lik", "N. mixte", "id"]  # , 'total_aa']

def memory_used():
    """Return the resident set size (RSS) of the current process, in MB.

    The value is read from the module-level ``process`` psutil handle.
    The current psutil accessor ``memory_info()`` is tried first; only when
    that attribute does not exist (older psutil releases) is the legacy
    ``get_memory_info()`` name used. Any other error is propagated instead
    of being silently swallowed.
    """
    try:
        mem_info = process.memory_info()
    except AttributeError:
        # Older psutil releases only expose the legacy accessor name.
        mem_info = process.get_memory_info()
    # 1024.0 keeps the division a float division on Python 2 as well.
    return mem_info.rss / (1024.0 * 1024)


def compile_result(x, spec_data, clf, cod_align, model, glimit, fpos, settings):
    """Run the classifier on one analysis item and write its report.

    ``x`` is a triple that unpacks into the reassignment object, the
    per-species data mapping and the list of species to consider. A nested
    hash is built from it (plus ``spec_data`` under ``'genome'``) and passed
    to ``read_from_json`` to obtain the feature matrix and its labels.

    Returns ``None`` when no usable feature matrix is obtained, or when every
    prediction is class 0 and ``settings.SKIP_EMPTY`` is set. Otherwise
    returns ``(rkp, [labels, pred, pred_prob, codvalid], data['aa'])`` after
    the text report has been written through ``utils.print_data_to_txt``.
    """
    reassign, per_species, species_list = x

    # Nested structure consumed by read_from_json
    payload = utils.makehash()
    payload['genome'] = spec_data
    for sp in species_list:
        payload['aa'][reassign.ori_aa1][reassign.dest_aa1][sp] = per_species[sp]

    features, labels, _ = read_from_json(
        payload, None, use_global=settings.USE_GLOBAL)

    outcome = None
    has_features = features is not None and features.size
    if has_features:
        # keep only the features the model was trained on
        features, printable, kept_labels = model.format_data(features)
        probas = clf.predict_proba(features)
        # predicted class = column with the highest probability
        calls = probas.argmax(axis=1)
        if sum(calls) == 0 and settings.SKIP_EMPTY:
            return None

        report = utils.get_report(
            reassign, per_species, cod_align,
            (features, labels, probas, calls), glimit, fpos, settings)
        sppval, outdir, rkp, codvalid = report

        txt_name = reassign.ori_aa + "_to_" + reassign.dest_aa + "_data.txt"
        utils.print_data_to_txt(
            os.path.join(outdir, txt_name), kept_labels, printable, labels,
            calls, probas, sppval, reassign.dest_aa, valid=codvalid)
        del printable
        outcome = (rkp, [labels, calls, probas, codvalid], payload['aa'])

    # drop local references before forcing a collection
    del reassign, features, per_species
    gc.collect()
    return outcome


if __name__ == '__main__':
    # Script entry point: parse the command line, load the runnable instance
    # given as input, run the reassignment analysis (optionally through a
    # process pool), classify each result with compile_result() and save
    # the predictions.

    # argument parser
    parser = argparse.ArgumentParser(
        description='CoreTracker, A codon reassignment tracker')

    parser.add_argument(
        '--wdir', '--outdir', dest="outdir", default="output", help="Working directory")

    parser.add_argument(
        'input', help="Input should be a runnable instance returned by coretracker-prep")

    parser.add_argument('--novalid', dest='valid', action='store_false',
                        help="Do not validate prediction by retranslating and checking alignment improvement")

    parser.add_argument('--expos', '--export_position', dest='expos', action="store_true",
                        help="Export a json file with the position of each reassignment in the corresponding genome.")

    parser.add_argument('--imformat', dest='imformat', choices=('pdf', 'png', 'svg'), default="pdf",
                        help="Image format to use for output (Codon_data file)")

    parser.add_argument('--aapair', dest='aapair',
                        help="Use a list of potential reassignments (see coretracker-prep's '.aa' output for example.")

    parser.add_argument('--parallel', dest='parallel', nargs='?', const=CPU_COUNT, type=int, default=0,
                        help="Use Parallelization during execution for each reassignment. This does not guarantee an increase in speed. CPU count will be used if no argument is provided")

    parser.add_argument('--memory_efficient', action="store_true",
                        dest='mefficient', help="Memory efficient execution")

    parser.add_argument('--version', action='version',
                        version='coretracker-prep v.%s' % __version__)
    parser.add_argument('--debug', dest='debug', action='store_true',
                        help="Enable debug printing")

    print("CoreTracker v:%s Copyright (C) %s %s" %
          (__version__, date, __author__))

    args = parser.parse_args()
    start_t = time.time()

    if args.debug:
        logging.basicConfig(level=logging.DEBUG)
    aasubset = None

    if args.aapair:
        try:
            with open(args.aapair) as f:
                # safe_load: the file is user supplied, and yaml.load()
                # without an explicit Loader is unsafe (and rejected by
                # recent PyYAML releases).
                aasubset = yaml.safe_load(f)
            # Flatten {aa3: [aa3, ...]} into a list of one-letter pairs
            aasubset = [(aa_letters_3to1[k], aa_letters_3to1[vv])
                        for k, v in aasubset.items() for vv in v]
        except (IOError, OSError, yaml.YAMLError, KeyError,
                AttributeError, TypeError):
            print("Input provided with --aapair is not valid !")
            # Invalid input is an error: report it through the exit status
            sys.exit(1)

    # `pool` stays None unless parallel execution was requested; it must be
    # defined in both cases because it is tested further down.
    pool = None
    n_jobs = 1
    if args.parallel > 0:
        pool = mp.Pool(args.parallel)

    # NOTE(review): joblib's load() unpickles its input; only feed it files
    # produced by a trusted coretracker-prep run.
    r_instance = load(args.input, mmap_mode='r+')
    logging.debug('Instance was read in %.3f s' % (time.time() - start_t))
    clf, model = r_instance.get_model(etiquette)
    nofilter = r_instance.args.nofilter
    reafinder = r_instance.rfinder
    settings = reafinder.settings
    settings.update_params(COMPUTE_POS=args.expos)
    settings.update_params(VALIDATION=args.valid)
    settings.update_params(IMAGE_FORMAT=args.imformat)

    if args.outdir:
        if not os.path.exists(args.outdir):
            os.makedirs(args.outdir)
            # let original error handling
        settings.update_params(OUTDIR=args.outdir)

    if args.mefficient:
        clf.clf.n_jobs = 1
        n_jobs = (args.parallel or 2)

    codon_align, fcodon_align = reafinder.seqset.get_codon_alignment()
    spec_data = reafinder.reassignment_mapper['genome']
    genelimit = reafinder.seqset.gene_limits
    filt_position = reafinder.seqset.filt_position
    cod_align = SeqIO.to_dict(fcodon_align)

    # Collect every non-None result of compile_result()
    results = []
    if pool is not None:
        partial_func = partial(compile_result, spec_data=spec_data, clf=clf, cod_align=cod_align, model=model, glimit=genelimit,
                               fpos=filt_position, settings=settings)
        # n_jobs is passed as the chunksize of imap_unordered
        for res in pool.imap_unordered(partial_func, reafinder.run_analysis(codon_align, fcodon_align, aasubset), n_jobs):
            if res is not None:
                results.append(res)

        pool.close()
        pool.join()
    else:
        for x in reafinder.run_analysis(codon_align, fcodon_align, aasubset):
            res = compile_result(
                x, spec_data, clf, cod_align, model, genelimit, filt_position, settings)
            if res is not None:
                results.append(res)

    # unzip data; zip(*[]) yields nothing, so the three-way unpacking would
    # raise ValueError when no result was produced.
    if results:
        results, ALL_PRED, rjson = zip(*results)
    else:
        logging.warning("No reassignment result was produced")
        # NOTE(review): assumes save_all() accepts empty inputs - confirm
        results, ALL_PRED, rjson = (), (), ()

    if args.valid and args.expos and results:
        # Merge the per-result position maps into one {species: {key: value}}
        rea_pos_keeper = ddict(dict)
        for r in results:
            for cuspec, readt in r.items():
                for k in readt.keys():
                    rea_pos_keeper[cuspec][k] = readt[k]
        exp_outfile = os.path.join(settings.OUTDIR, "positions.json")
        reafinder.export_position(rea_pos_keeper, exp_outfile)

    reafinder.save_all(ALL_PRED, rjson, True, nofilter=nofilter)

    logging.info("\n**END (%.3f s, %.3f MB)" %
                 (abs(time.time() - start_t),  memory_used()))
