#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from Bio import SeqIO
from ete3 import Tree
from functools import partial
from collections import defaultdict as ddict
from collections import namedtuple
from multiprocessing import dummy
from coretracker import __author__, __version__, date
from coretracker.settings import Settings
from coretracker.coreutils import *
from coretracker.coreutils.instance import RunningInstance
import argparse
import glob
import logging
import os
import uuid
import time
import psutil
import yaml

# Handle on the current process; memory_used() reads its memory counters.
process = psutil.Process(os.getpid())

# joblib-based parallelism is optional. ENABLE_PAR is switched off when
# neither joblib nor the copy bundled with scikit-learn can be imported,
# and CPU_COUNT stays at 0 when multiprocessing.cpu_count is unavailable.
ENABLE_PAR = True
CPU_COUNT = 0
try:
    from multiprocessing import cpu_count
    CPU_COUNT = cpu_count()
    from joblib import Parallel, delayed
except ImportError:
    try:
        from sklearn.externals.joblib import Parallel, delayed
    except Exception:
        # Best effort: any failure while importing the fallback simply
        # disables parallelism. Unlike a bare ``except``, this does not
        # swallow KeyboardInterrupt or SystemExit.
        ENABLE_PAR = False

# Prefer the C implementation of the YAML loader when PyYAML provides it.
try:
    from yaml import CLoader as Loader
except ImportError:
    from yaml import Loader

# Bundle of sequence-related parameters built by param_getter() and handed
# to load_seqset().
SeqParams = namedtuple('SeqParams', [
                       'msaprg', 'use_tree', 'hmmfiles', 'codemap', 'spec_restricted_list', 'scale', 'speciestree'])


def AlignmentProgramType(program):
    """Validate an alignment program command line (argparse ``type`` callable).

    The first whitespace-separated token of ``program`` must be ``muscle``
    or ``mafft``; any following tokens are left untouched.

    Args:
        program: The command string given on the command line.

    Returns:
        ``program`` unchanged when it is accepted.

    Raises:
        argparse.ArgumentTypeError: If the string is empty, contains only
            whitespace, or names an unsupported program.
    """
    tokens = program.split()
    # An empty value used to raise IndexError, which argparse does not
    # convert into a clean usage error; reject it explicitly instead.
    if not tokens or tokens[0] not in ('muscle', 'mafft'):
        raise argparse.ArgumentTypeError(
            "Accepted alignment programs are mafft and muscle")
    return program


def memory_used():
    """Return the resident set size of the current process, in megabytes.

    Reads the module-level ``process`` handle.
    """
    try:
        rss = process.memory_info().rss
    except AttributeError:
        # Fall back to the legacy accessor name for psutil releases that
        # do not provide memory_info(). Only a missing attribute triggers
        # the fallback, so genuine errors are no longer hidden.
        rss = process.get_memory_info().rss
    return rss / (1024.0 * 1024)


def param_getter(args):
    """Set all data for coretracker from the argument list.

    Creates the output directory when it does not exist yet, loads the
    species tree and collects the optional alignment inputs.

    Args:
        args: Parsed command-line namespace. The attributes read here are
            ``outdir``, ``align``, ``tree``, ``usetree``, ``hmmdir``,
            ``codemap`` and ``restrict_to``.

    Returns:
        SeqParams: with fields
            msaprg -- alignment command (``' --auto'`` is appended when the
                command is exactly ``'mafft'``);
            use_tree -- tree path to use as guide tree, or None;
            hmmfiles -- dict mapping gene name to hmm file path;
            codemap -- dict mapping species name to integer genetic code;
            spec_restricted_list -- list of species names;
            scale -- value of the ``--use_tree`` option, or None;
            speciestree -- ``Tree`` built from ``args.tree``.
    """
    # Check mafft command input
    def progcmd(x): return x + ' --auto' if x == 'mafft' else x

    if args.outdir and not os.path.exists(args.outdir):
        os.makedirs(args.outdir)

    msaprg = progcmd(args.align)
    use_tree = None
    speciestree = Tree(args.tree)
    scale = None
    if args.usetree:
        if args.align != 'mafft':
            logging.warning("Cannot usetree with muscle alignment")
        else:
            use_tree = args.tree
        # The '--use_tree' option is stored under dest='usetree'; reading
        # args.use_tree raised AttributeError whenever the option was given.
        scale = args.usetree

    hmmfiles = {}
    if args.hmmdir:
        try:
            for hmmpath in glob.glob(os.path.join(args.hmmdir, '*')):
                # File names follow '<genename>.hmm'.
                genename = os.path.basename(hmmpath).split('.hmm')[0]
                hmmfiles[genename] = hmmpath
        except Exception as err:
            # Still best effort, but no longer silent.
            logging.warning("Could not list hmm files in %s: %s",
                            args.hmmdir, err)

    codemap = {}
    if args.codemap:
        # One '<species> <code>' pair per line; blank lines and lines
        # starting with '#' are skipped.
        with open(args.codemap) as CMAP:
            for line in CMAP:
                line = line.strip()
                if line and not line.startswith('#'):
                    spec, code = line.split()
                    codemap[spec] = int(code)

    spec_restricted_list = []
    if args.restrict_to:
        try:
            with open(args.restrict_to) as SREST:
                for line in SREST:
                    line = line.strip()
                    if line and not line.startswith('#'):
                        spec_restricted_list.append(line)
        except IOError:
            # Not a readable file: treat the value itself as one species name.
            spec_restricted_list.append(args.restrict_to)

    return SeqParams(msaprg=msaprg, use_tree=use_tree, hmmfiles=hmmfiles, codemap=codemap, spec_restricted_list=spec_restricted_list, scale=scale, speciestree=speciestree)


def load_seqset(seqpair, settings=None, args=None, seqp=None):
    """Build a filtered SequenceSet from one (nucleotide, protein) input pair.

    Args:
        seqpair: 2-tuple ``(dnaseq, protseq)``.
        settings: Settings object; also supplies ``GENETIC_CODE``.
        args: Parsed command-line namespace (filter and refine options).
        seqp: SeqParams tuple as returned by ``param_getter``.

    Returns:
        The SequenceSet after ``filter_sequences`` has been applied.
    """
    nuc_input, prot_input = seqpair
    gap_threshold = args.gapfilter
    loader_options = {
        'use_tree': seqp.use_tree,
        'refine_alignment': args.refine,
        'msaprog': seqp.msaprg,
        'hmmdict': seqp.hmmfiles,
    }
    # Note the argument order: protein input first, then nucleotides.
    loader = SequenceLoader(prot_input, nuc_input, settings,
                            gap_threshold, **loader_options)
    dataset = SequenceSet(loader, seqp.speciestree,
                          settings.GENETIC_CODE, codemap=seqp.codemap)
    dataset.filter_sequences(
        args.idfilter,
        args.gapfilter,
        args.iccontent,
        args.rmconst,
        nofilter=args.nofilter,
    )
    return dataset


if __name__ == '__main__':
    # Script entry point: parse the command line, build one SequenceSet per
    # (nucleotide, protein) input pair, assemble a RunningInstance from them
    # and save it to disk.

    # argument parser
    # Short random id used in the default output file name.
    runid = uuid.uuid1().hex[:5]
    prepare = argparse.ArgumentParser(
        description='coretracker-prep, prepare CoreTracker run for large dataset')

    prepare.add_argument('--version', action='version',
                         version='coretracker-prep v.%s' % __version__)
    prepare.add_argument('--gapfilter', '--gap', type=float, default=0.6, dest='gapfilter',
                         help="Remove position with more than `gapfilter` gap from the alignment, using gapfilter as threshold (default :0.6)")
    prepare.add_argument('--idfilter', '--id', type=float, default=0.5, dest='idfilter',
                         help="Conserve only position with at least `idfilter` residue identity (default : 0.5)")
    prepare.add_argument('--icfilter', '--ic', type=float, default=0.5, dest='iccontent',
                         help="Shannon entropy threshold (default : 0.5 ). This will be used to discard column where IC < max(icvector)*icfilter)")
    prepare.add_argument('-t', '--intree', dest="tree",
                         help='Input species tree in newick format', required=True)
    prepare.add_argument('--protseq', '--prot', '-p', dest='protseq', nargs='+',
                         help="Protein sequence input in core format", required=True)
    prepare.add_argument('--dnaseq', '--dna', '-n', dest='dnaseq', nargs='+',
                         help="Nucleotides sequences input in core format", required=True)
    prepare.add_argument('--rmconst', dest='rmconst', action='store_true',
                         help="Remove constant site from filtered alignment. ")
    prepare.add_argument('--norefine', dest='refine', action='store_false',
                         help="Do not refine the alignment. By default the alignment will be refined. This option should never be used if you have made your own multiple alignment and concatenate it. Else you will have absurd alignment (TO FIX)")
    prepare.add_argument('--nofilter', '--nf', action="store_true", dest="nofilter",
                         help="Do not filter sequence alignment.")
    prepare.add_argument('--align', dest='align', nargs='?', type=AlignmentProgramType,
                         const="muscle", help="Choose a program to align your sequences")
    # NOTE(review): no type= is given, so a user-supplied scale arrives as a
    # string while the implicit const is the float 1.0 -- confirm that the
    # consumer of Settings.SCALE copes with both.
    prepare.add_argument('--use_tree', dest='usetree', nargs='?', const=1.0,
                         help="This is helpfull only if the mafft alignment is selected. Perform multiple alignment, using species tree as guide tree. A scaling value is needed to compute the branch format for the '--treein' option of mafft. If you're unsure, use the default value. The tree must be rooted and binary. See http://mafft.cbrc.jp/alignment/software/treein.html")
    prepare.add_argument('--hmmdir', dest='hmmdir',
                         help="Link a directory with hmm files for alignment. Each hmmfile should be named in the following format : genename.hmm")
    prepare.add_argument('--params', dest='params',
                         help="Use A parameter file to load parameters. If a parameter is not set, the default will be used")
    prepare.add_argument('--parallel', dest='parallel', nargs='?', const=CPU_COUNT, type=int, default=0,
                         help="Use Parallelization during execution for each reassignment. This does not guarantee an increase in speed. CPU count will be used if no argument is provided")
    prepare.add_argument('--restrict_to', dest='restrict_to',
                         help="Restrict analysis to list of genomes only")
    prepare.add_argument('--codemap', dest='codemap',
                         help="A tab delimited file that map each species to its genetic code. Default code (--gcode) will be used for missing species")
    prepare.add_argument('--memory_efficient', action="store_true",
                         help="Will implement a set of decision to compute reduce memory usage (for example threads instead of pools). Can be slower though.")
    prepare.add_argument('--out', '--outfile', dest='outfile',
                         default='core_%s' % runid, help="Output file where to save")
    prepare.add_argument('--wdir', '--outdir', dest="outdir",
                         default="output", help="Working directory")
    prepare.add_argument('--verbose', dest='verbose',
                         action="store_true", help="Verbose printing")

    args = prepare.parse_args()
    if args.verbose:
        logging.basicConfig(level=logging.INFO)

    logging.info("CoreTracker v:%s Copyright (C) %s %s" %
                 (__version__, date, __author__))

    logging.info("Setting running parameters...")

    start_t = cur_t = time.time()
    if args.nofilter:
        # set the filter parameters to none
        args.idfilter = None
        args.gapfilter = None
        args.iccontent = None

    if len(args.dnaseq) != len(args.protseq):
        # argparse.ArgumentError needs (argument, message); calling it with
        # a single string raised TypeError instead of reporting the problem.
        # parser.error() prints the usage plus this message and exits.
        prepare.error("Nucleotide and protein sequence should match")

    # Materialise the pairs: on Python 3 zip() is a one-shot iterator.
    dna_prot_pair = list(zip(args.dnaseq, args.protseq))
    setting = Settings()
    paramfile = args.params
    ad_params = {}
    if paramfile:
        with open(paramfile) as f:
            # SECURITY NOTE: the full YAML Loader can instantiate arbitrary
            # Python objects. Only use parameter files from a trusted source.
            ad_params = yaml.load(f, Loader=Loader)

    seqp = param_getter(args)
    if seqp.scale:
        # The scale lives on the SeqParams tuple; the bare name `scale`
        # was never defined in this scope and raised NameError.
        setting.SCALE = seqp.scale
    setting.fill(ad_params)
    setting.update_params(OUTDIR=args.outdir)

    if seqp.spec_restricted_list:
        setting.LIMIT_TO_SUSPECTED_SPECIES = True
    seqset_loader = partial(
        load_seqset, settings=setting, args=args, seqp=seqp)

    # `done` records whether one of the parallel strategies produced the
    # sequence sets; otherwise the sequential loop below runs.
    done = False
    seqsets = []
    if args.parallel > 0:
        if args.memory_efficient:
            # multiprocessing.dummy provides a thread-backed pool.
            pool = dummy.Pool(args.parallel)
            seqsets = pool.map(seqset_loader, dna_prot_pair)
            pool.close()
            pool.join()
            done = True

        elif ENABLE_PAR:
            seqsets = Parallel(n_jobs=args.parallel, verbose=int(args.verbose))(
                delayed(seqset_loader)(x) for x in dna_prot_pair)
            done = True
        else:
            logging.warning(
                "Joblib requirement not found! Disabling parallelization")
    if not done:
        for dnaseq, protseq in dna_prot_pair:
            seqsets.append(seqset_loader((dnaseq, protseq)))

    logging.info("Filtered sequences and constructed datasets (%.3f s, using %.3f MB)" % (
        abs(time.time() - cur_t), memory_used()))

    cur_t = time.time()
    run_instance = RunningInstance.from_seqsets(
        seqsets, seqp.spec_restricted_list, setting, args)

    logging.info("Found potential reassignment (%.3f s, using %.3f MB)" % (
        abs(time.time() - cur_t), memory_used()))
    # now save and reload then check again
    outputfile = os.path.join(setting.OUTDIR, args.outfile)
    run_instance.save_instance(outfile=outputfile)
    logging.info("Done! Running instance is saved at %s.\nYou can now use coretracker-run !\nTotal time: %.3f s, using %.3f MB" %
                 (outputfile + ".pgz", abs(time.time() - start_t), memory_used()))
