#!python

"""Test different diversifying pressure models with ``phydms``.

Written by Jesse Boom and Sarah Hilton.
"""


import sys
import os
import re
import time
import logging
import multiprocessing
import subprocess
import glob
import phydmslib
import phydmslib.file_io
import phydmslib.parsearguments
import pandas as pd
import random

def randomizeDivpressure(divpressureFile,numberRandomizations,outprefix):
    """
    Write shuffled copies of a diversifying pressure file.

    The values read from *divpressureFile* are shuffled
    *numberRandomizations* times, using the seeds 0, 1, 2, ... in turn.
    The list of values is shuffled in place, so every shuffle starts from
    the ordering left behind by the previous seed. Site labels keep their
    original order; only the values assigned to them move.

    Each shuffled assignment is written as a ``#SITE,VALUE`` csv file in
    the directory ``<outprefix>randomizedFiles``, which is created if it
    does not exist yet (a regular file of that name is removed first).

    Returns a dictionary mapping each seed to the file written for it.
    """
    baseName = os.path.basename(divpressureFile)
    baseName = os.path.splitext(baseName)[0]

    # dirname() drops the trailing slash, leaving the bare directory path
    targetDir = os.path.dirname(outprefix + "randomizedFiles/")
    if targetDir and not os.path.isdir(targetDir):
        if os.path.isfile(targetDir):
            os.remove(targetDir)
        os.mkdir(targetDir)

    observed = phydmslib.file_io.readDivPressure(divpressureFile)
    siteLabels = list(observed.keys())
    shuffledValues = list(observed.values())

    writtenFiles = {}
    for seed in range(int(numberRandomizations)):
        # seed the module-level generator so each file is reproducible
        random.seed(seed)
        random.shuffle(shuffledValues)
        fileName = '/{0}_random_{1}.csv'.format(baseName, seed)
        outPath = targetDir + fileName
        rows = ['{0},{1}'.format(label, value)
                for label, value in zip(siteLabels, shuffledValues)]
        with open(outPath, 'w') as handle:
            handle.write("#SITE,VALUE\n")
            handle.write('\n'.join(rows))
        writtenFiles[seed] = outPath
    return writtenFiles

def createModelComparisonFile(models, outprefix):
    """
    Creates two summary files, a .csv and a .md

    For every model ID in *models* the files
    ``<outprefix>ExpCM_<ID parts>_modelparams.txt`` and
    ``<outprefix>ExpCM_<ID parts>_loglikelihood.txt`` are read, where
    ``<ID parts>`` are the elements of the model ID that are not "None",
    joined by underscores. Only the keys of *models* are used.

    modelcomparison.csv
        *Preferences*
            The prefs file used for `ExpCM`
        *DiversifyingPressureSet*
            The diversifying pressure file
        *DiversifyingPressureType*
            "True" if the diversifying pressure values were used as given
            "None" if `ExpCM` was run without diversifying pressure
            "Random" if the true values were randomized
        *DiversifyingPressureID*
            Combination of the diversifying pressure set and the random
            seed. Identifies each run of `phydms`
        *variable*
            All of the optimized parameters and the final log likelihood
            (stored under the name "LogLikelihood")
        *value*
            The value of the variable
    *modelcomparison.md*
        Columns include 'DiversifyingPressureSet', 'LogLikelihood',
        and all optimized parameters.
        Only the non-randomized `phydms` runs are included.

    The log likelihood is taken from the text after the last " = " on the
    first line of the loglikelihood file. Surrounding whitespace, including
    the line's trailing newline, is stripped so that the value does not
    break the row layout of either output file.
    """
    columns = ["Preferences", "DiversifyingPressureSet",
            "DiversifyingPressureType", "DiversifyingPressureID",
            "variable", "value"]

    #modelcomparison.csv
    # Collect one frame per input file and concatenate once at the end
    # instead of growing a data frame inside the loop.
    frames = []
    for modelID in models.keys():
        outFilePrefix = "ExpCM_" + '_'.join(str(item) for item in modelID
                if item != "None")
        divpressureID = modelID[1] + "_" + str(modelID[3])

        # each line of the parameter file reads "<name> = <value>"
        params = pd.read_csv(outprefix + outFilePrefix + "_modelparams.txt",
                sep=" = ", header=None, engine='python')
        params.columns = ['variable', 'value'] #rename the columns

        with open(outprefix + outFilePrefix + "_loglikelihood.txt") as f:
            lines = f.readlines()
        loglik = pd.DataFrame({"variable": ["LogLikelihood"],
                "value": [lines[0].split(" = ")[-1].strip()]})

        for frame in (params, loglik):
            frame["Preferences"] = modelID[0]
            frame["DiversifyingPressureSet"] = modelID[1]
            frame["DiversifyingPressureType"] = modelID[2]
            frame["DiversifyingPressureID"] = divpressureID
            frames.append(frame[columns])

    if frames:
        final = pd.concat(frames, ignore_index=True)
    else:
        final = pd.DataFrame(columns=columns)
    final.to_csv(outprefix + "modelcomparison.csv", index=False)

    #modelcomparison.md
    #pivot the dataframe so the parameters are columns
    mdDF = final[final["DiversifyingPressureType"] != "Random"].pivot(
            index="DiversifyingPressureSet", columns='variable',
            values='value')
    header = ["DiversifyingPressureSet"] + list(mdDF.columns.values)
    with open(outprefix + "modelcomparison.md", "w") as f:
        f.write("|".join(header) + "\n")
        f.write("|".join(["---"] * len(header)) + "\n")
        for index, row in mdDF.iterrows():
            f.write("|".join(str(x) for x in [index] + list(row)) + "\n")

def RunCmds(cmds):
    """Runs the command line arguments in *cmds* using *subprocess*.

    *cmds* is the argument list handed to ``subprocess.Popen``. The child's
    standard output and standard error are captured and discarded, and the
    call blocks until the child exits. Returns ``None``.

    If anything interrupts the call (for example the command cannot be
    launched, or the waiting process is interrupted), a child that is still
    running is sent a termination request and the original exception is
    re-raised so that the real cause is reported.
    """
    p = None
    try:
        p = subprocess.Popen(cmds, stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT)
        p.communicate()
    except BaseException:
        # p is still None when Popen itself failed; only signal a child
        # that was actually started and has not exited yet
        if p is not None and p.poll() is None:
            p.terminate()
        raise

def main():
    """
    Main body of script.

    *modelID* tuple
        Each run of `phydms` is considered one model by this script.
        Each model has unique ID which describes the preferences and the
        diversifying pressures used in the `ExpCM`.
        The ID is a tuple of the following form:
            (pref file, diversifying pressure (name of file or "None"),
            diversifying pressure type ("None", "True", or "Random"), and the
            random seed (integer or "None"))
    *models* dictionary keyed by modelID
        each value is a tuple of the following form:
            (`phydms` model specification, list of additional `phydms`
            arguments, number of empirical parameters, modelprefix)

    The branch lengths are optimized using `ExpCM` without diversifying pressure
    on either the tree specified by the user or the tree inferred under the
    *GTRCAT* with `RAxML`. The `ExpCM` with diversifying pressures are run with
    this optimized tree and the branch lengths are scaled but not individually
    optimized.

    After all runs finish, the model comparison files are written and the
    per-model `phydms` output files and any randomized diversifying pressure
    files are deleted.

    Any error raised during the `phydms` runs is logged and then re-raised,
    so a failed analysis ends with a non-zero exit status. Child processes
    that are still running are terminated in either case.
    """

    def modelprefix(modelID):
        """Output file prefix for *modelID*: its non-"None" parts joined."""
        return "ExpCM_" + '_'.join(str(item) for item in modelID
                if item != "None")

    # Parse command line arguments
    parser = phydmslib.parsearguments.PhyDMSTestdivpressureParser()
    args = vars(parser.parse_args())
    prog = parser.prog

    #create output directory if needed
    outdir = os.path.dirname(args['outprefix'])
    if outdir:
        if not os.path.isdir(outdir):
            if os.path.isfile(outdir):
                os.remove(outdir)
            os.mkdir(outdir)

    #setup files
    #file names depend on whether outprefix is directory or file
    if args['outprefix'][-1] == '/':
        logfile = "{0}log.log".format(args['outprefix'])
    else:
        logfile = "{0}.log".format(args['outprefix'])
        args['outprefix'] = '{0}_'.format(args['outprefix'])
    raxmlStandardOutputFile = '{0}ramxl_output.txt'.format(args['outprefix'])

    #Set up to log everything to logfile.
    #This must happen exactly once: every extra FileHandler attached to the
    #logger would repeat each message.
    if os.path.isfile(logfile):
        os.remove(logfile)
    logging.shutdown()
    logging.captureWarnings(True)
    versionstring = phydmslib.file_io.Versions()
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
            level=logging.INFO)
    logger = logging.getLogger(prog)
    logfile_handler = logging.FileHandler(logfile)
    logger.addHandler(logfile_handler)
    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    logfile_handler.setFormatter(formatter)

    #print some basic information
    logger.info('Beginning execution of {0} in directory {1}\n'
            .format(prog, os.getcwd()))
    logger.info('Progress is being logged to {0}\n'.format(logfile))
    logger.info('{0}\n'.format(versionstring))
    logger.info('Parsed the following command-line arguments:\n{0}\n'
            .format('\n'.join(['\t{0} = {1}'.format(key, args[key])
            for key in args.keys()])))

    #setup models
    filesuffixlist = ['_log.log', '_tree.newick', '_loglikelihood.txt',
            '_modelparams.txt']
    filesuffixes = {} # keyed by model, values are list of suffixes
    models = {}

    #set up the ExpCM without diversifying pressure
    expcmadditionalcmds = []
    nempiricalExpCM = 3
    if re.search(r'\s', args['prefsfile']):
        raise ValueError("There is a space in the preferences file name: "
                "{0}".format(args['prefsfile']))
    prefsfilebase = os.path.splitext(os.path.basename(args['prefsfile']))[0]
    modelID = (prefsfilebase, "None", "None", "None")
    assert modelID not in filesuffixes, "Duplicate model ID"
    filesuffixes[modelID] = filesuffixlist
    models[modelID] = ('ExpCM_{0}'.format(args['prefsfile']),
            expcmadditionalcmds, nempiricalExpCM, modelprefix(modelID))

    #set up ExpCM with diversifying pressure
    for divpressure in args["divpressure"]:
        divpressurefilebase = os.path.splitext(os.path.basename(divpressure))[0]
        modelID = (prefsfilebase, divpressurefilebase, "True", "None")
        assert modelID not in filesuffixes, "Duplicate model ID (divpressure)"
        filesuffixes[modelID] = filesuffixlist
        models[modelID] = ('ExpCM_{0}'.format(args['prefsfile']),
                expcmadditionalcmds + ["--divpressure", divpressure],
                nempiricalExpCM, modelprefix(modelID))
        #set up ExpCM with diversifying pressure randomizations
        if args["randomizations"]:
            randomFiles = randomizeDivpressure(divpressure,
                    args["randomizations"], args["outprefix"])
            # the loop variable must not be called `random`, which would
            # hide the module of that name inside this function
            for seed in randomFiles.keys():
                modelID = (prefsfilebase, divpressurefilebase, "Random", seed)
                assert modelID not in filesuffixes, "Duplicate model ID (divpressure)"
                filesuffixes[modelID] = filesuffixlist
                models[modelID] = ('ExpCM_{0}'.format(args['prefsfile']),
                        expcmadditionalcmds + ["--divpressure", randomFiles[seed]],
                        nempiricalExpCM, modelprefix(modelID))

    # check alignment
    logger.info('Checking that the alignment {0} is valid...'
            .format(args['alignment']))
    alignment = phydmslib.file_io.ReadCodonAlignment(args['alignment'],
            checknewickvalid=True) #checks that the identifiers are unique
    sequences = [align[1] for align in alignment]
    assert len(set(sequences)) == len(sequences), \
            "Remove duplicate sequences from {0}.".format(args["alignment"])
    logger.info('Valid alignment specifying {0} sequences of length {1}.\n'
            .format(len(alignment), len(alignment[0][1])))

    #read or build a tree
    if args["tree"]:
        logger.info("Reading tree from {0}".format(args['tree']))
    else:
        logger.info("Tree not specified.")
        # Only the version query is guarded: it is the call that fails with
        # OSError when the raxml executable cannot be launched. File system
        # errors later on are left to surface unchanged.
        try:
            subprocess.check_output([args['raxml'], '-v'],
                    stderr=subprocess.STDOUT)
        except OSError:
            raise ValueError("The raxml command of {0} is not valid."
                    " Is raxml installed at this path?".format(args['raxml']))
        logger.info("Inferring tree with RAxML using command {0}"
                .format(args['raxml']))

        #remove pre-existing RAxML files
        raxmlOutputName = os.path.splitext(os.path.basename(
                args["alignment"]))[0]
        raxmlOutputFiles = []
        for raxmlFile in glob.glob("RAxML_*{0}".format(raxmlOutputName)):
            if os.path.isfile(raxmlFile):
                raxmlOutputFiles.append(raxmlFile)
                os.remove(raxmlFile)
        if len(raxmlOutputFiles) > 0:
            logger.info('Removed the following RAxML files:\n{0}\n'
                .format('\n'.join(['\t{0}'.format(fname) for
                fname in raxmlOutputFiles])))

        # run RAxmL
        raxmlCMD = [args['raxml'], '-s', args['alignment'], '-n',
                raxmlOutputName, '-m', 'GTRCAT', '-p1', '-T', '2']
        with open(raxmlStandardOutputFile, 'w') as f:
            subprocess.check_call(raxmlCMD, stdout = f)

        # move RAxML tree to output directory and remove all other files
        for raxmlFile in glob.glob("RAxML_*{0}".format(raxmlOutputName)):
            if "bestTree" in raxmlFile:
                args["tree"] = args["outprefix"] + 'RAxML_tree.newick'
                os.rename(raxmlFile, args['tree'])
                logger.info("RAxML inferred tree is now named {0}".format(
                        args['tree']))
            else:
                os.remove(raxmlFile)

    #setting up cpus
    if args['ncpus'] == -1:
        try:
            args['ncpus'] = multiprocessing.cpu_count()
        except NotImplementedError:
            raise RuntimeError("Encountered a problem trying to dynamically"
                    " determine the number of available CPUs. Please manually"
                    " specify the number of desired CPUs with '--ncpus' and"
                    " try again.")
        logger.info('Will use all %d available CPUs.\n' % args['ncpus'])
    assert args['ncpus'] >= 1, "Failed to specify valid number of CPUs"

    #Each model gets at least 1 CPU
    ncpus_per_model = {}
    nperexpcm = max(1, (args['ncpus'] - 2) // len(models))
    for modelID in models.keys():
        ncpus_per_model[modelID] = nperexpcm
        mtup = models[modelID]
        assert len(mtup) == 4 # should be 4-tuple
        models[modelID] = (mtup[0], mtup[1] + ['--ncpus',
                str(ncpus_per_model[modelID])], mtup[2], mtup[3])

    pool = {} # holds process for model name
    started = {} # holds whether process started for model name
    completed = {} # holds whether process completed for model name
    outprefixes = {} # holds outprefix for model name

    # rest of execution in try / finally
    try:

        # remove existing output files
        outfiles = []
        removed = []
        for modelID in models.keys():
            for suffix in filesuffixes[modelID]:
                fname = "{0}{1}{2}".format(args['outprefix'], models[modelID][3], suffix)
                outfiles.append(fname)
        for fname in outfiles:
            if os.path.isfile(fname):
                os.remove(fname)
                removed.append(fname)
        if removed:
            logger.info('Removed the following existing files that have names'
                    " that match the names of output files that will be "
                    'created: {0}\n'.format(', '.join(removed)))

        # first run the `ExpCM` model without diversifying pressure to
        # optimize the branch lengths
        noDivPressure = [x for x in list(models.keys()) if x[1] == "None"]
        assert len(noDivPressure) == 1, ("More than one ExpCM without "
                "diversifying pressure was specified")
        noDivPressure = noDivPressure[0]
        (model, additionalcmds, nempirical, modelname) = models[noDivPressure]
        outprefix = "{0}{1}".format(args['outprefix'], modelname)
        cmds = ['phydms', args['alignment'], args["tree"], model,
                outprefix] + additionalcmds + ["--brlen", "optimize"]
        logger.info('Starting analysis. Optimizing the branch lengths. The command is: %s\n' % (' '.join(cmds)))
        outprefixes[noDivPressure] = outprefix
        subprocessOutput = '{0}subprocess_output.txt'.format(args['outprefix'])
        with open(subprocessOutput, 'w') as f:
            subprocess.check_call(cmds, stdout = f)
        if os.path.isfile(subprocessOutput):
            os.remove(subprocessOutput)
        for fname in [outprefixes[noDivPressure] + suffix for suffix
                in filesuffixes[noDivPressure]]:
            if not os.path.isfile(fname):
                raise RuntimeError("phydms failed to created"
                        " expected output file {0}.".format(fname))
        # report the model that was just run, not whichever model ID an
        # earlier loop happened to leave behind
        logger.info('Analysis successful for {0} without diversifying pressure.\n'\
                .format(noDivPressure[0]))
        treeFile = '{0}_tree.newick'.format(outprefix)
        assert os.path.isfile(treeFile)

        #now run the other models
        divPressureModels = [x for x in list(models.keys()) if x[1] != "None"]
        for modelID in divPressureModels:
            (model, additionalcmds, nempirical, modelname) = models[modelID]
            outprefix = "{0}{1}".format(args['outprefix'], modelname)
            cmds = ['phydms', args['alignment'], treeFile, model,
                    outprefix] + additionalcmds + ["--brlen", "scale"]
            pool[modelID] = multiprocessing.Process(target=RunCmds,
                    args=(cmds,))
            outprefixes[modelID] = outprefix
            completed[modelID] = False
            started[modelID] = False
        while not all(completed.values()):
            nrunning = list(started.values()).count(True) - \
                    list(completed.values()).count(True)
            # start at most one new process per pass through the loop
            if nrunning < args['ncpus']:
                for (modelID, p) in pool.items():
                    if not started[modelID]:
                        p.start()
                        started[modelID] = True
                        break
            for (modelID, p) in pool.items():
                if started[modelID] and (not completed[modelID]) and \
                        (not p.is_alive()): # process just completed
                    completed[modelID] = True
                    for fname in [outprefixes[modelID] + suffix for suffix
                            in filesuffixes[modelID]]:
                        if not os.path.isfile(fname):
                            raise RuntimeError("phydms failed to created"
                                    " expected output file {0}.".format(fname))
                    # every model in the pool has a diversifying pressure
                    # set; only the type distinguishes the two messages
                    if modelID[2] != "True":
                        logger.info('Analysis successful for {0} with diversifying pressure {1} and random seed {2}\n'\
                            .format(modelID[0], modelID[1], modelID[3]))
                    else:
                        logger.info('Analysis successful for {0} with diversifying pressure {1}\n'\
                            .format(modelID[0], modelID[1]))
            time.sleep(1)

        #make sure all expected output files are there
        for fname in outfiles:
            if not os.path.isfile(fname):
                raise RuntimeError("Cannot find expected output file {0}"
                        .format(fname))
        createModelComparisonFile(models, args["outprefix"])
        # the per-model outputs are summarized above and no longer needed
        for fname in outfiles:
            if os.path.isfile(fname):
                os.remove(fname)
        if args["randomizations"]:
            for f in glob.glob(args["outprefix"] + "randomizedFiles/*"):
                os.remove(f)
            os.rmdir(args["outprefix"] + "randomizedFiles/")

    except BaseException:
        logger.exception('Terminating {0} at {1} with ERROR'
                .format(prog, time.asctime()))
        # re-raise so the failure is reflected in the exit status
        raise
    else:
        logger.info('Successful completion of {0}'.format(prog))
    finally:
        logging.shutdown()
        for p in pool.values():
            if p.is_alive():
                p.terminate()


# Run the analysis only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main() # run the script

# TODO: make the test script
# TODO: figure out why the logging prints twice
