#!/home/travis/build/pycbc-build/environment/bin/python

# Copyright (C) 2016 Christopher M. Biwer, Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Creates a DAX for a parameter estimation workflow.
"""

import argparse
import h5py
import logging
import os
import Pegasus.DAX3 as dax
import pycbc.workflow as wf
import pycbc.workflow.inference_followups as inffu
import pycbc.workflow.minifollowups as mini
import pycbc.workflow.pegasus_workflow as wdax
import pycbc.version
import socket
import sys
from pycbc_glue import segments
from pycbc import results
from pycbc.results import layout
from pycbc.types import MultiDetOptionAction
from pycbc.types import MultiDetOptionAppendAction
from pycbc.workflow import WorkflowConfigParser

def to_file(path, ifo=None):
    """ Wraps a path given as a str in a
    pycbc.workflow.pegasus_workflow.File instance.

    The File is constructed from the basename of ``path``, gets ``ifo``
    assigned to its ``ifo`` attribute, and has the absolute form of ``path``
    registered as a PFN on the "local" site.
    """
    workflow_file = wdax.File(os.path.basename(path))
    workflow_file.ifo = ifo
    workflow_file.PFN(os.path.abspath(path), "local")
    return workflow_file

def symlink_path(f, path):
    """ Creates a symlink named ``f.name`` inside the directory ``path`` that
    points at ``f.storage_path``.

    Nothing is done if ``f`` is None. This is best effort: any OSError raised
    while creating the link is ignored.
    """
    if f is not None:
        link_name = os.path.join(path, f.name)
        try:
            os.symlink(f.storage_path, link_name)
        except OSError:
            # deliberately ignored, the link is optional
            pass

# command line parser
# the module docstring starts with a space, so drop the first character
parser = argparse.ArgumentParser(description=__doc__[1:])

# version option
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg,
                    help="Prints version information.")

# workflow options
# NOTE: the default name is kept as is (including its spelling) because it
# is used to name the workflows and changing it would rename outputs
parser.add_argument("--workflow-name", default="my_unamed_inference_run",
                    help="Name of the workflow to append in various places.")
parser.add_argument("--tags", nargs="+", default=[],
                    help="Tags to append in various places.")
parser.add_argument("--output-dir", default=None,
                    help="Path to output directory.")
parser.add_argument("--output-map", required=True,
                    help="Path to output map file.")
parser.add_argument("--output-file", required=True,
                    help="Path to DAX file.")

# input workflow file options
# all three must be given for the event times to be read from the statmap
# file, otherwise --gps-end-time is used
parser.add_argument("--bank-file", default=None,
                    help="HDF format template bank file.")
parser.add_argument("--statmap-file", default=None,
                    help="HDF format clustered coincident trigger "
                         "result file.")
parser.add_argument("--single-detector-triggers", nargs="+", default=None,
                    action=MultiDetOptionAction,
                    help="HDF format merged single detector trigger files.")

# analysis time option
# only required if not using input workflow file options
parser.add_argument("--gps-end-time", type=float, nargs="+", default=None,
                    help="Times to analyze. If given workflow files then "
                         "this option is ignored.")

# input configuration file options
parser.add_argument("--inference-config-file", type=str, required=True,
                    help="workflow.WorkflowConfigParser parsable file with "
                         "prior information.")
parser.add_argument("--prior-section", type=str, default="prior",
                    help="Name of the section in inference configuration file "
                         "that contains priors.")

# input frame files
parser.add_argument("--frame-files", nargs="+", default=None,
                    action=MultiDetOptionAppendAction,
                    help="GWF frame files to use.")

# add option groups
wf.add_workflow_command_line_group(parser)

# parse command line
opts = parser.parse_args()

# log to terminal until we know where the path to log output file
log_format = "%(asctime)s:%(levelname)s : %(message)s"
logging.basicConfig(format=log_format, level=logging.INFO)

# create workflow and sub-workflows
# container is the top-level workflow; the "-main" and "-finalization"
# workflows are added to it as sub-workflows near the end of the script
container = wf.Workflow(opts, opts.workflow_name)
workflow = wf.Workflow(opts, opts.workflow_name + "-main")
finalize_workflow = wf.Workflow(opts, opts.workflow_name + "-finalization")

# sections for output HTML pages
# rdir[<section>] is used below as the directory for that section and
# rdir.base as the top-level results directory
rdir = layout.SectionNumber("results",
                            ["detector_sensitivity", "priors", "posteriors",
                             "samples", "workflow"])

# make data output and results directories
wf.makedir(opts.output_dir)
wf.makedir(rdir.base)
wf.makedir(rdir["workflow"])

# create files for workflow log
# the plain text file receives the log records; the HTML file is written at
# the very end of the script from the contents of the text file
log_file_txt = wf.File(workflow.ifos, "workflow-log", workflow.analysis_time,
                      extension=".txt", directory=rdir["workflow"])
log_file_html = wf.File(workflow.ifos, "WORKFLOW-LOG", workflow.analysis_time,
                        extension=".html", directory=rdir["workflow"])

# switch saving log to file
# NOTE(review): basicConfig was already called above, and the logging
# documentation says it does nothing once the root logger has handlers, so
# this second call presumably has no effect -- confirm. The FileHandler added
# to the root logger below is what sends records to the text file.
logging.basicConfig(format=log_format, level=logging.INFO,
                    filename=log_file_txt.storage_path, filemode="w")
log_file = logging.FileHandler(filename=log_file_txt.storage_path, mode="w")
log_file.setLevel(logging.INFO)
formatter = logging.Formatter(log_format)
log_file.setFormatter(formatter)
logging.getLogger("").addHandler(log_file)
logging.info("Created log file %s" % log_file_txt.storage_path)

# sanity check that user is either using workflow or command line
# the event times come either from the workflow files (bank, statmap and
# single detector triggers must all be given) or from --gps-end-time
workflow_options = opts.bank_file is not None \
                   and opts.statmap_file is not None \
                   and opts.single_detector_triggers is not None
if not workflow_options and opts.gps_end_time is None:
    raise ValueError("Must use either workflow options or --gps-end-time")

# typecast str from command line to File instances
# NOTE: the results sections, the workflow log files and the logging file
# handler are set up once earlier in the script; they must not be created
# again here, otherwise the log file is opened a second time in "w" mode and
# a second file handler is attached to the root logger
config_file = to_file(opts.inference_config_file)

# if using workflow files to find analysis times
if workflow_options:

    # typecast str from command line to File instances
    tmpltbank_file = to_file(opts.bank_file)
    coinc_file = to_file(opts.statmap_file)
    single_triggers = []
    for ifo in opts.single_detector_triggers:
        fname = opts.single_detector_triggers[ifo]
        single_triggers.append(to_file(fname, ifo=ifo))

    # get number of events to analyze
    num_events = int(workflow.cp.get_opt_tags("workflow-inference",
                                              "num-events", ""))

    # read everything needed from the statmap file inside a with block so
    # that the HDF file is closed again as soon as the arrays are in memory
    with h5py.File(opts.statmap_file, "r") as f:

        # get detection statistic from statmap file
        # if there are fewer events than requested in the config file
        # then analyze all events
        stat = f["foreground/stat"][:]
        num_events = min(num_events, len(stat))

        # get indices of the loudest events, sorted by descending
        # detection statistic
        loudest = stat.argsort()[::-1][0:num_events]

        # get times for loudest events, keyed by the detector names stored
        # in the file attributes
        times = {
            f.attrs["detector_1"]: f["foreground/time1"][:][loudest],
            f.attrs["detector_2"]: f["foreground/time2"][:][loudest],
        }

# else get analysis times from command line
# every IFO gets the same list of times
else:
    times = {ifo: opts.gps_end_time for ifo in workflow.ifos}
    num_events = len(opts.gps_end_time)

# Executable used to create one sampler node per event
inference_exe = wf.Executable(workflow.cp, "inference",
                              ifos=workflow.ifos, out_dir=opts.output_dir)

# parse the inference configuration file
cp = WorkflowConfigParser([opts.inference_config_file])

# look up the "<ifo>-channel-name" option of the workflow section for each
# IFO and flatten the result into a single "IFO:CHANNEL IFO:CHANNEL" string
channel_names = {
    ifo: workflow.cp.get_opt_tags("workflow",
                                  "%s-channel-name" % ifo.lower(), "")
    for ifo in workflow.ifos
}
channel_names_str = " ".join(
    ":".join((ifo_name, channel))
    for ifo_name, channel in channel_names.items())

# every "plot-1d-<group>" option of the workflow-inference section defines
# a group of parameters to plot, given as a space-separated list
plot_prefix = "plot-1d-"
plot_parameters = {}
for option in workflow.cp.options("workflow-inference"):
    if not option.startswith(plot_prefix):
        continue
    group = option.replace(plot_prefix, "").replace("-", "_")
    option_value = workflow.cp.get_opt_tag("workflow", option, "inference")
    plot_parameters[group] = option_value.split(" ")

# flat list of the parameters of all groups
all_parameters = []
for group_parameters in plot_parameters.values():
    all_parameters.extend(group_parameters)

# loop over number of loudest events to be analyzed
# each iteration adds one sampler node to the main workflow plus the plotting
# jobs and HTML layout for that event
for num_event in range(num_events):

    # set GPS times for reading in data around the event
    # the reference time is the mean over detectors of this event's time
    avg_end_time = sum([end_times[num_event] \
                          for end_times in times.values()]) / len(times.keys())
    # NOTE(review): the two padding options do not depend on num_event but
    # are read again from the configuration on every iteration
    seconds_before_time = int(workflow.cp.get_opt_tags(
                                            "workflow-inference",
                                            "data-seconds-before-trigger", ""))
    seconds_after_time = int(workflow.cp.get_opt_tags(
                                            "workflow-inference",
                                            "data-seconds-after-trigger", ""))
    # the mean time is truncated to an integer before padding
    gps_start_time = int(avg_end_time) - seconds_before_time
    gps_end_time = int(avg_end_time) + seconds_after_time

    # get dict of segments for each IFO
    # every IFO gets the same single [gps_start_time, gps_end_time) segment
    seg_dict = {ifo : segments.segmentlist([segments.segment(
                      gps_start_time, gps_end_time)]) for ifo in workflow.ifos}

    # get frame files from command line or datafind server
    frame_files = wf.FileList([])
    if opts.frame_files:
        # wrap every path given with --frame-files in a File with a file://
        # URL and register its storage path as a PFN on the local site
        for ifo in workflow.ifos:
            for path in opts.frame_files[ifo]:
                frame_file = wf.File(
                                ifo, "FRAME",
                                segments.segment(gps_start_time, gps_end_time),
                                file_url="file://" + path)
                frame_file.PFN(frame_file.storage_path, site="local")
                frame_files.append(frame_file)
    else:
        # only the first of the four values returned is used
        frame_files, _, _, _ = wf.setup_datafind_workflow(workflow, seg_dict,
                                                          "datafind")

    # make node for running sampler
    node = inference_exe.create_node()
    node.add_opt("--instruments", " ".join(workflow.ifos))
    node.add_opt("--gps-start-time", gps_start_time)
    node.add_opt("--gps-end-time", gps_end_time)
    node.add_multiifo_input_list_opt("--frame-files", frame_files)
    node.add_opt("--channel-name", channel_names_str)
    node.add_input_opt("--config-file", config_file)
    analysis_time = segments.segment(gps_start_time, gps_end_time)
    # the output file is tagged with the event number
    inference_file = node.new_output_file_opt(analysis_time, ".hdf",
                                              "--output-file",
                                              tags=[str(num_event)])

    # add node to workflow
    workflow += node

    # file that prints information about the search event
    # only possible when the search workflow files were given
    # NOTE(review): coinc_table_files, post_table_files and post_files are
    # rebound on every iteration, so code after the loop only sees the values
    # of the last event
    if workflow_options:
        coinc_table_files = [mini.make_coinc_info(workflow, single_triggers,
                                         tmpltbank_file, coinc_file,
                                         num_event, rdir.base,
                                         tags=opts.tags + [str(num_event)])]

    # files for posteriors summary subsection
    base = "posteriors"
    post_table_files = inffu.make_inference_summary_table(
                          workflow, inference_file, rdir[base],
                          variable_args=all_parameters,
                          analysis_seg=analysis_time,
                          tags=opts.tags + [str(num_event)])
    post_files = inffu.make_inference_posterior_plot(
                          workflow, inference_file, rdir[base],
                          parameters=all_parameters,
                          analysis_seg=analysis_time,
                          tags=opts.tags + [str(num_event)])
    layout.single_layout(rdir[base], post_table_files + post_files)

    # files for posteriors 1d_posteriors subsection
    # one subsection per parameter group, in alphabetical order of group name
    # NOTE: calling sort() on the result of keys() requires keys() to return
    # a list, as it does in Python 2
    groups = plot_parameters.keys()
    groups.sort()
    for group in groups:
        parameters = plot_parameters[group]
        base = "posteriors/1d_{}_posteriors".format(group)
        post_1d_files = inffu.make_inference_1d_posterior_plots(
                          workflow, inference_file, rdir[base],
                          parameters=parameters,
                          analysis_seg=analysis_time,
                          tags=opts.tags + [str(num_event)])
        layout.group_layout(rdir[base], post_1d_files)

    # files for samples summary subsection
    # one samples plot per parameter, tagged with the parameter name
    base = "samples"
    samples_files = []
    for parameter in all_parameters:
        samples_files += inffu.make_inference_samples_plot(
                          workflow, inference_file, rdir[base],
                          parameters=[parameter], analysis_seg=analysis_time,
                          tags=opts.tags + [str(num_event), parameter])
    layout.group_layout(rdir[base], samples_files)

    # files for samples acceptance_rate subsection
    base = "samples/acceptance_rate"
    acceptance_files = inffu.make_inference_acceptance_rate_plot(
                          workflow, inference_file, rdir[base],
                          analysis_seg=analysis_time,
                          tags=opts.tags + [str(num_event)])
    layout.single_layout(rdir[base], acceptance_files)

    # files for detector_sensitivity summary subsection
    # the spectrum plot is made from the sampler output file
    base = "detector_sensitivity"
    psd_file = wf.make_spectrum_plot(workflow, [inference_file], rdir[base],
                                     tags=opts.tags + [str(num_event)])
    layout.single_layout(rdir[base], [psd_file])

# files for priors summary section
# one prior plot per subsection of the prior section of the inference
# configuration file
base = "priors"
prior_files = []
for subsection in cp.get_subsections(opts.prior_section):
    prior_files += inffu.make_inference_prior_plot(
                          workflow, config_file, rdir[base],
                          analysis_seg=workflow.analysis_time,
                          sections=[opts.prior_section + "-" + subsection],
                          tags=opts.tags + [subsection])
layout.group_layout(rdir[base], prior_files)

# files for overall summary section
# start from the search event information (if the search workflow files were
# given) and extend it with the posterior table and plot; the list is copied
# so that coinc_table_files itself is not modified
summ_files = list(coinc_table_files) if workflow_options else []
summ_files += post_table_files + post_files

# link the summary files into the top-level results directory and lay them
# out on the front page
for summ_file in summ_files:
    symlink_path(summ_file, rdir.base)
layout.single_layout(rdir.base, summ_files)

# versioning HTML pages go in a subsection of the workflow section
version_dir = rdir["workflow/version"]
results.create_versioning_page(version_dir, container.cp)

# the node that makes the HTML pages belongs to the finalization sub-workflow
html_dir = os.path.join(os.getcwd(), rdir.base)
wf.make_results_web_page(finalize_workflow, html_dir)

# attach both sub-workflows to the container workflow
for sub_workflow in (workflow, finalize_workflow):
    container += sub_workflow

# the finalization sub-workflow may only start after the main sub-workflow
container._adag.addDependency(
    dax.Dependency(parent=workflow.as_job, child=finalize_workflow.as_job))

# write the DAX and the output map
container.save(filename=opts.output_file, output_map_path=opts.output_map)

# store a copy of the workflow configuration in the results pages
workflow_config_dir = rdir["workflow/configuration"]
wf.makedir(workflow_config_dir)
wf_ini = workflow.save_config("workflow.ini", workflow_config_dir,
                              container.cp)
layout.single_layout(workflow_config_dir, wf_ini)

# store a copy of the inference (prior) configuration in the results pages
prior_config_dir = rdir["priors/configuration"]
wf.makedir(prior_config_dir)
prior_ini = workflow.save_config("priors.ini", prior_config_dir, cp)
layout.single_layout(prior_config_dir, prior_ini)

# shut logging down so the handlers are closed before the text log is read
# back and embedded in the HTML log page
logging.shutdown()
with open(log_file_txt.storage_path, "r") as log_handle:
    log_data = log_handle.read()

# HTML body: where, under which name and on which host the workflow was
# generated, followed by the verbatim log
log_template = """
<p>Workflow generation script created workflow in output directory: %s</p>
<p>Workflow name is: %s</p>
<p>Workflow generation script run on host: %s</p>
<pre>%s</pre>
"""
log_str = log_template % (os.getcwd(), opts.workflow_name,
                          socket.gethostname(), log_data)

# write the HTML log with its metadata and add it to the workflow section
results.save_fig_with_metadata(
    log_str, log_file_html.storage_path,
    title="Workflow Generation Log",
    caption="Log of the workflow script %s" % sys.argv[0],
    cmd=" ".join(sys.argv))
layout.single_layout(rdir["workflow"], [log_file_html])
