#!python

"""Comprehensive model comparison and selection-detection with ``phydms``.

Written by Jesse Bloom and Sarah Hilton."""


import sys
import os
import re
import time
import copy
import logging
import multiprocessing
import subprocess
import signal
import glob
import phydmslib
import phydmslib.file_io
import phydmslib.parsearguments
import pandas as pd

def RunCmds(cmds):
    """Runs the command line arguments in *cmds* using *subprocess*.

    *cmds* is a list of strings giving the executable and its arguments;
    it is passed to ``subprocess.Popen`` without a shell.

    The child's standard output and standard error are captured and
    discarded. The function blocks until the child exits and returns
    `None`.

    This is a best-effort runner: if anything goes wrong (including
    failure to launch the command, or an interrupt while waiting), any
    child process that is still running is terminated and the error is
    not propagated.
    """
    p = None
    try:
        p = subprocess.Popen(cmds, stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT)
        # communicate() drains the pipe so the child cannot block on a
        # full stdout buffer; the captured output is intentionally dropped.
        p.communicate()
    except BaseException:
        # `p` is still None if Popen itself failed, in which case there is
        # no child to clean up. Only signal a child that has not yet exited.
        if p is not None and p.poll() is None:
            p.terminate()

def createModelComparisonFile(models, outprefix, modelcomparisonfile):
    """Creates `modelcomparisonfile` in Markdown format.

    `models` is an iterable of the model names used for each run.

    `outprefix` is the string prepended to each model name to locate that
    model's output files. For each name in `models`, two files are read:

        * ``<outprefix><model>_loglikelihood.txt``: the last
          whitespace-delimited token is taken as the log likelihood.

        * ``<outprefix><model>_modelparams.txt``: one ``name = value``
          entry per line. Blank lines are ignored.

    `modelcomparisonfile` is the name of the created file. It holds a
    Markdown table with the columns 'Model', 'deltaAIC', 'LogLikelihood',
    'nParams', and 'ParamValues'. Rows are sorted by increasing 'deltaAIC',
    which is each model's AIC (computed here as
    ``2 * (nParams - LogLikelihood)``) minus the smallest AIC among all
    models. 'nParams' counts every entry in the model parameters file,
    while 'ParamValues' lists only the parameters without 'phi' in their
    name.
    """

    columns = ['Model', 'deltaAIC', 'LogLikelihood', 'nParams', 'ParamValues']
    stats = {col: [] for col in columns}
    aics = []
    for modelName in models:
        stats['Model'].append(modelName)
        with open(outprefix + modelName + '_loglikelihood.txt') as f:
            loglik = float(f.read().split()[-1])
        with open(outprefix + modelName + '_modelparams.txt') as f:
            # Skip blank lines so they are neither counted as parameters
            # nor break the (name, value) unpacking below.
            params = [line.split('=') for line in f if line.strip()]
        paramvalues = ['{0}={1:.2f}'.format(name.strip(), float(value)) for
                (name, value) in sorted(params) if 'phi' not in name]
        stats['LogLikelihood'].append(loglik)
        stats['nParams'].append(len(params))
        stats['ParamValues'].append(', '.join(paramvalues))
        aics.append(2 * (len(params) - loglik))

    # Report AIC relative to the best (lowest AIC) model.
    minAIC = min(aics)
    stats['deltaAIC'] = [aic - minAIC for aic in aics]

    # Row order: increasing deltaAIC, ties broken by original position.
    # This must be computed before the numbers are turned into strings.
    indexorder = [i for (_, i) in
            sorted(zip(stats['deltaAIC'], range(len(aics))))]

    for col in ['deltaAIC', 'LogLikelihood']:
        stats[col] = ['{0:.2f}'.format(x) for x in stats[col]]
    stats['nParams'] = [str(x) for x in stats['nParams']]

    # Each column is padded to the width of its widest cell or its header.
    colwidths = [max(map(len, stats[col] + [col])) for col in columns]
    formatline = ('| ' + ' | '.join(['{' + str(i) + ':' + str(w) + '}'
            for (i, w) in enumerate(colwidths)]) + ' |')
    sepline = '|-' + '-|-'.join(['-' * w for w in colwidths]) + '-|'
    text = [formatline.format(*columns), sepline]
    for i in indexorder:
        text.append(formatline.format(*[stats[col][i] for col in columns]))
    with open(modelcomparisonfile, 'w') as f:
        f.write('\n'.join(text))


def main():
    """Main body of script.

    Performs the following steps in order:

        1. Parses the command-line arguments and sets up logging to both
           the console and a log file derived from ``outprefix``.

        2. Builds the set of models to fit: ``YNGKP_M0``, ``YNGKP_M5``,
           and for every preferences file an ``ExpCM`` plus whichever
           gamma-omega / gamma-beta / averaged / randomized variants
           were requested.

        3. Validates the alignment, and infers a tree with RAxML if no
           tree was given.

        4. Assigns CPUs to each model, runs ``phydms`` for every model in
           a separate process, and checks that each run created its
           expected output files.

        5. Writes the model comparison Markdown file.

    Errors raised while running the models (step 4 onward) are logged and
    then re-raised so that the script terminates with a failure status.
    Any model processes still running at that point are terminated.
    """

    # Parse command line arguments
    parser = phydmslib.parsearguments.PhyDMSComprehensiveParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    # create output directory if needed
    outdir = os.path.dirname(args['outprefix'])
    if outdir:
        if not os.path.isdir(outdir):
            # a plain file with the directory's name is replaced
            if os.path.isfile(outdir):
                os.remove(outdir)
            os.mkdir(outdir)

    # setup files
    # file names slightly different depending on
    # whether outprefix is directory or file
    if args['outprefix'][-1] == '/':
        logfile = "{0}log.log".format(args['outprefix'])
    else:
        logfile = "{0}.log".format(args['outprefix'])
        args['outprefix'] = '{0}_'.format(args['outprefix'])
    modelcomparisonfile = '{0}modelcomparison.md'.format(args['outprefix'])
    raxmlStandardOutputFile = '{0}raxml_output.txt'.format(args['outprefix'])

    # Set up to log everything to logfile.
    if os.path.isfile(logfile):
        os.remove(logfile)
    logging.shutdown()
    logging.captureWarnings(True)
    versionstring = phydmslib.file_io.Versions()
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
        level=logging.INFO)
    logger = logging.getLogger(prog)
    logfile_handler = logging.FileHandler(logfile)
    logger.addHandler(logfile_handler)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    logfile_handler.setFormatter(formatter)

    # print some basic information
    logger.info('Beginning execution of {0} in directory {1}\n'
            .format(prog, os.getcwd()))
    logger.info('Progress is being logged to {0}\n'.format(logfile))
    logger.info('{0}\n'.format(versionstring))
    logger.info('Parsed the following command-line arguments:\n{0}\n'
            .format('\n'.join(['\t{0} = {1}'.format(key, args[key])
            for key in args.keys()])))

    # setup models
    # `models` is keyed by model name, values are (model, extra phydms args)
    # `filesuffixes` is keyed by model name, values are list of suffixes
    filesuffixlist = ['_log.log', '_tree.newick', '_loglikelihood.txt',
            '_modelparams.txt']
    filesuffixes = {}
    additionalcmds = ["--brlen", args["brlen"]]
    if args['omegabysite']:
        additionalcmds.append('--omegabysite')
        filesuffixlist = filesuffixlist + ['_omegabysite.txt']

    # set up the YNGKP models
    models = {'YNGKP_M0':('YNGKP_M0', additionalcmds)}
    filesuffixes['YNGKP_M0'] = filesuffixlist
    models['YNGKP_M5'] = ('YNGKP_M5', additionalcmds)
    filesuffixes['YNGKP_M5'] = filesuffixlist

    # set up the ExpCM
    # New lists are built here so the YNGKP entries above are not affected
    # by the ExpCM-only options.
    additionalcmds = list(additionalcmds)
    if args['diffprefsbysite']:
        additionalcmds = additionalcmds + ['--diffprefsbysite']
        filesuffixlist = filesuffixlist + ['_diffprefsbysite.txt']

    # Every ExpCM is run for each (preference treatment, gamma option)
    # combination requested. Each entry is (name decoration, extra args).
    prefstreatments = [('', [])]
    if not args['noavgprefs']:
        prefstreatments.append(('averaged_', ['--avgprefs']))
    if args['randprefs']:
        prefstreatments.append(('randomized_', ['--randprefs']))
    gammaoptions = [('', [])]
    if args["gammaomega"]:
        gammaoptions.append(('_gammaomega', ['--gammaomega']))
    if args["gammabeta"]:
        gammaoptions.append(('_gammabeta', ['--gammabeta']))

    for prefsfile in args['prefsfiles']:
        if re.search(r'\s', prefsfile):
            raise ValueError("There is a space in the preferences file name: "
                    "{0}".format(prefsfile))
        prefsfilebase = os.path.splitext(os.path.basename(prefsfile))[0]
        modelname = 'ExpCM_{0}'.format(prefsfilebase)
        assert modelname not in filesuffixes, "Duplicate preferences file"\
                " base name {0} for {1}; make names unique even after "\
                "removing directory and extension".format(modelname, prefsfile)
        expcm = 'ExpCM_{0}'.format(prefsfile)
        for (prefix, prefscmds) in prefstreatments:
            for (suffix, gammacmds) in gammaoptions:
                variantname = '{0}{1}{2}'.format(prefix, modelname, suffix)
                models[variantname] = (expcm,
                        additionalcmds + gammacmds + prefscmds)
                filesuffixes[variantname] = filesuffixlist

    # check alignment
    logger.info('Checking that the alignment {0} is valid...'
            .format(args['alignment']))
    alignment = phydmslib.file_io.ReadCodonAlignment(args['alignment'],
            checknewickvalid=True) #checks that the identifiers are unique
    logger.info('Valid alignment specifying {0} sequences of length {1}.\n'
            .format(len(alignment), len(alignment[0][1])))

    # read or build a tree
    if args["tree"]:
        logger.info("Reading tree from {0}".format(args['tree']))
    else:
        logger.info("Tree not specified.")
        # Probe the executable first so that a bad path is reported
        # clearly. Only this probe is guarded: later file-system errors
        # are unrelated to whether raxml is installed.
        try:
            subprocess.check_output([args['raxml'], '-v'],
                    stderr=subprocess.STDOUT)
        except OSError:
            raise ValueError("The raxml command of {0} is not valid."
                    " Is raxml installed at this path?".format(args['raxml']))
        logger.info("Inferring tree with RAxML using command {0}"
                .format(args['raxml']))

        # remove pre-existing RAxML files
        raxmlOutputName = os.path.splitext(os.path.basename(
                args["alignment"]))[0]
        raxmlOutputFiles = []
        for raxmlFile in glob.glob("RAxML_*{0}".format(raxmlOutputName)):
            if os.path.isfile(raxmlFile):
                raxmlOutputFiles.append(raxmlFile)
                os.remove(raxmlFile)
        if len(raxmlOutputFiles) > 0:
            logger.info('Removed the following RAxML files:\n{0}\n'
                .format('\n'.join(['\t{0}'.format(fname) for
                fname in raxmlOutputFiles])))

        # run RAxmL
        raxmlCMD = [args['raxml'], '-s', args['alignment'], '-n',
                raxmlOutputName, '-m', 'GTRCAT', '-p1', '-T', '2']
        with open(raxmlStandardOutputFile, 'w') as f:
            subprocess.check_call(raxmlCMD, stdout = f)

        # move RAxML tree to output directory and remove all other files
        for raxmlFile in glob.glob("RAxML_*{0}".format(raxmlOutputName)):
            if "bestTree" in raxmlFile:
                args["tree"] = args["outprefix"] + 'RAxML_tree.newick'
                os.rename(raxmlFile, args['tree'])
                logger.info("RAxML inferred tree is now named {0}".format(
                        args['tree']))
            else:
                os.remove(raxmlFile)
        if not args["tree"]:
            # Without this check the failure would only surface later as
            # an obscure error when building the phydms command.
            raise RuntimeError("RAxML did not create a bestTree file. See "
                    "{0} for the RAxML output.".format(
                    raxmlStandardOutputFile))

    # get number of available CPUs and assign to each model
    if args['ncpus'] == -1:
        try:
            args['ncpus'] = multiprocessing.cpu_count()
        except NotImplementedError:
            raise RuntimeError("Encountered a problem trying to dynamically"
                    " determine the number of available CPUs. "
                    "Please manually specify the number of desired CPUs with"
                    " '--ncpus' and try again.")
        logger.info('Will use all %d available CPUs.\n' % args['ncpus'])
    assert args['ncpus'] >= 1, "Failed to specify valid number of CPUs"

    # YNGKP models get one CPU
    # ExpCM get more than one if excess over number of models
    expcm_modelnames = [modelname for modelname in models.keys() \
            if 'ExpCM' in modelname]
    yngkp_modelnames = [modelname for modelname in models.keys() \
            if 'YNGKP' in modelname]
    assert len(models.keys()) == len(expcm_modelnames) + len(yngkp_modelnames)\
            , "not ExpCM or YNGKP:\n{0}".format(str(models.keys()))
    ncpus_per_model = dict([(modelname, 1) for modelname in yngkp_modelnames])
    if expcm_modelnames:
        # two CPUs are set aside for the two YNGKP models
        nperexpcm = max(1, (args['ncpus'] - 2) // len(expcm_modelnames))
        for modelname in expcm_modelnames:
            ncpus_per_model[modelname] = nperexpcm
    for modelname in list(models.keys()):
        mtup = models[modelname]
        models[modelname] = (mtup[0], mtup[1] + ['--ncpus',
                str(ncpus_per_model[modelname])])


    pool = {} # holds process for model name
    started = {} # holds whether process started for model name
    completed = {} # holds whether process completed for model name
    outprefixes = {} # holds outprefix for model name

    # rest of execution in try / finally
    try:

        # remove existing output files
        outfiles = [modelcomparisonfile]
        removed = []
        for modelname in models.keys():
            for suffix in filesuffixes[modelname]:
                fname = "{0}{1}{2}".format(args['outprefix'], modelname, suffix)
                outfiles.append(fname)
        for fname in outfiles:
            if os.path.isfile(fname):
                os.remove(fname)
                removed.append(fname)
        if removed:
            logger.info('Removed the following existing files that have names'
                    " that match the names of output files that will be "
                    'created: {0}\n'.format(', '.join(removed)))

        # now run the models
        # First create (but do not start) one process per model.
        for modelname in list(models.keys()):
            (model, additionalcmds) = models[modelname]
            outprefix = "{0}{1}".format(args['outprefix'], modelname)
            cmds = ['phydms', args['alignment'], args["tree"], model,
                    outprefix] + additionalcmds
            logger.info('Starting analysis to optimize tree in {0} using model '
                    '{1}. The command is: {2}\n'.format( args["tree"],
                    modelname, ' '.join(cmds)))
            pool[modelname] = multiprocessing.Process(target=RunCmds\
                    , args=(cmds,))
            outprefixes[modelname] = outprefix
            completed[modelname] = False
            started[modelname] = False
        # Poll once per second: start at most one new process per pass
        # while fewer than `ncpus` are running, and check the output of
        # any process that has just finished.
        while not all(completed.values()):
            nrunning = list(started.values()).count(True) - \
                    list(completed.values()).count(True)
            if nrunning < args['ncpus']:
                for (modelname, p) in pool.items():
                    if not started[modelname]:
                        p.start()
                        started[modelname] = True
                        break
            for (modelname, p) in pool.items():
                if started[modelname] and (not completed[modelname]) and \
                        (not p.is_alive()): # process just completed
                    completed[modelname] = True
                    logger.info('Analysis completed for {0}'.format(modelname))
                    for fname in [outprefixes[modelname] + suffix for suffix
                            in filesuffixes[modelname]]:
                        if not os.path.isfile(fname):
                            raise RuntimeError("phydms failed to created"
                                    " expected output file {0}.".format(fname))
                        logger.info("Found expected output file {0}"\
                                .format(fname))
                    logger.info('Analysis successful for {0}\n'\
                            .format(modelname))
            time.sleep(1)

        createModelComparisonFile(models.keys(), args["outprefix"],
                modelcomparisonfile)

        # make sure all expected output files are there
        for fname in outfiles:
            if not os.path.isfile(fname):
                raise RuntimeError("Cannot find expected output file {0}"
                        .format(fname))

    except:
        logger.exception('Terminating {0} at {1} with ERROR'
                .format(prog, time.asctime()))
        # Re-raise after logging so the failure is not reported as a
        # successful exit; the `finally` clause below still runs.
        raise
    else:
        logger.info('Successful completion of {0}'.format(prog))
    finally:
        logging.shutdown()
        # do not leave model processes running after an error
        for p in pool.values():
            if p.is_alive():
                p.terminate()


# Call main() only when this file is executed as a script, so that
# importing it does not start an analysis.
if __name__ == '__main__':
    main() # run the script
