#! /usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Runs a sampler to find the posterior distributions.
"""

import os
import argparse
import logging
import numpy
import pycbc
import pycbc.opt
import pycbc.weave
import random
import pycbc
from pycbc import fft
from pycbc import inference
from pycbc import psd
from pycbc import scheme
from pycbc import strain
from pycbc import types
from pycbc import waveform
from pycbc.io.inference_hdf import InferenceFile
from pycbc.workflow import WorkflowConfigParser
from pycbc.inference import option_utils

def select_waveform_generator(approximant):
    """ Returns the generator class for the approximant.

    Parameters
    ----------
    approximant : str
        Name of the waveform approximant.

    Returns
    -------
    class
        One of the generator classes in `pycbc.waveform`:
        `FDomainCBCGenerator`, `TDomainCBCGenerator`,
        `FDomainRingdownGenerator` or `FDomainMultiModeRingdownGenerator`.

    Raises
    ------
    ValueError
        If the approximant is a time domain ringdown, is a frequency domain
        ringdown for which no generator is known here, or is not listed by
        any of the approximant collections that are checked.
    """
    # the checks are ordered: frequency domain is tried before time domain
    if approximant in waveform.fd_approximants():
        return waveform.FDomainCBCGenerator
    if approximant in waveform.td_approximants():
        return waveform.TDomainCBCGenerator
    if approximant in waveform.ringdown_fd_approximants:
        if approximant == 'FdQNM':
            return waveform.FDomainRingdownGenerator
        if approximant == 'FdQNMmultiModes':
            return waveform.FDomainMultiModeRingdownGenerator
        # fail loudly rather than implicitly returning None, which would
        # only surface later as an obscure error in the generator setup
        raise ValueError("No generator is available for the frequency "
                         "domain ringdown approximant %s." % approximant)
    if approximant in waveform.ringdown_td_approximants:
        raise ValueError("Time domain ringdowns not supported")
    raise ValueError("%s is not a valid approximant."%approximant)


def convert_liststring_to_list(lstring):
    """ Converts the string representation of a list, as written in a
    configuration file, into a list of strings.

    Parameters
    ----------
    lstring : str
        String that starts with '[' and ends with ']', with the elements
        separated by commas, e.g. "['H1', 'L1']".

    Returns
    -------
    list of str
        The elements, with surrounding whitespace and single quotes
        stripped. Note that "[]" yields [''] since the (empty) contents are
        still split on commas.

    Raises
    ------
    ValueError
        If `lstring` is not enclosed in square brackets.
    """
    if not (lstring.startswith('[') and lstring.endswith(']')):
        raise ValueError("%r is not the string representation of a list"
                         % (lstring,))
    # split the contents once, then clean up each element
    return [str(item.strip().strip("'")) for item in lstring[1:-1].split(',')]


def _gates_from_cli(opts, gate_opt):
    """Parses the given `gate_opt` into something understandable by
    `strain.gate_data`.

    Parameters
    ----------
    opts : object
        Parsed command line options; the attribute named `gate_opt` is read.
        It must be None or an iterable of strings formatted as
        IFO:CENTRALTIME:HALFDUR:TAPERDUR.
    gate_opt : str
        Name of the attribute of `opts` to parse, e.g. 'gate' or 'psd_gate'.

    Returns
    -------
    dict
        Dictionary keyed by ifo; each value is a list of
        (central time, half duration, taper duration) tuples of floats, in
        the order given on the command line. Empty if the option is None.

    Raises
    ------
    ValueError
        If a gate does not have exactly four colon-separated fields, or if
        one of the three numeric fields cannot be converted to a float.
    """
    gates = {}
    gate_args = getattr(opts, gate_opt)
    if gate_args is None:
        return gates
    for gate in gate_args:
        try:
            ifo, central_time, half_dur, taper_dur = gate.split(':')
            gate_params = (float(central_time), float(half_dur),
                           float(taper_dur))
        except ValueError:
            # name the option that was actually parsed (attribute names use
            # underscores, command line options use dashes)
            raise ValueError("--{} {} not formatted correctly; ".format(
                gate_opt.replace('_', '-'), gate) + "see help")
        gates.setdefault(ifo, []).append(gate_params)
    return gates


def gates_from_cli(opts):
    """Parses the --gate option into something understandable by
    `strain.gate_data`.

    Returns the dictionary built by `_gates_from_cli` from `opts.gate`.
    """
    return _gates_from_cli(opts, 'gate')


def psd_gates_from_cli(opts):
    """Parses the --psd-gate option into something understandable by
    `strain.gate_data`.

    Returns the dictionary built by `_gates_from_cli` from `opts.psd_gate`.
    """
    return _gates_from_cli(opts, 'psd_gate')


def apply_gates_to_td(strain_dict, gates):
    """Gates the time-domain strain of every ifo that has gates defined.

    Parameters
    ----------
    strain_dict : dict
        Time-domain strain, keyed by ifo. The dictionary itself is not
        modified.
    gates : dict
        Gates to apply, keyed by the ifo they apply to. Each value is handed
        to `strain.gate_data` together with that ifo's strain.

    Returns
    -------
    dict
        New dictionary with the same keys as `strain_dict`; entries for the
        ifos in `gates` are replaced by the output of `strain.gate_data`.
    """
    # shallow copy so that the caller's dictionary is left untouched
    gated = {ifo: data for ifo, data in strain_dict.items()}
    for ifo, ifo_gates in gates.items():
        logging.info("Gating {} strain".format(ifo))
        gated[ifo] = strain.gate_data(gated[ifo], ifo_gates)
    return gated


def apply_gates_to_fd(stilde_dict, gates):
    """Gates frequency-domain strain for every ifo that has gates defined.

    The strain of each gated ifo is converted with `to_timeseries`, gated
    with `apply_gates_to_td`, and converted back with `to_frequencyseries`.

    Parameters
    ----------
    stilde_dict : dict
        Frequency-domain strain, keyed by ifo. The dictionary itself is not
        modified.
    gates : dict
        Gates to apply, keyed by the ifo they apply to; passed on to
        `apply_gates_to_td`.

    Returns
    -------
    dict
        New dictionary with the same keys as `stilde_dict`; entries for the
        ifos in `gates` are replaced by their gated frequency series.
    """
    # shallow copy so that the caller's dictionary is left untouched
    gated = {ifo: stilde for ifo, stilde in stilde_dict.items()}
    # only the ifos that have gates need to go to the time domain
    td_strain = {}
    for ifo in gates:
        td_strain[ifo] = gated[ifo].to_timeseries()
    # gate in the time domain, then transform back
    gated_td = apply_gates_to_td(td_strain, gates)
    for ifo in gated_td:
        gated[ifo] = gated_td[ifo].to_frequencyseries()
    return gated


# command line usage
# The parser is built at module level; the module docstring is reused as the
# program description.
parser = argparse.ArgumentParser(usage=__file__ + " [--options]",
                                 description=__doc__)

# add data options
parser.add_argument("--instruments", type=str, nargs="+", required=True,
                    help="IFOs, eg. H1 L1.")
parser.add_argument("--frame-type", type=str, nargs="+",
                    action=types.MultiDetOptionAction,
                    metavar="IFO:FRAME_TYPE",
                    help="Frame type for each IFO.")
# a single cutoff is read here; the script later copies it into a per-ifo
# dictionary before estimating the PSDs
parser.add_argument("--low-frequency-cutoff", type=float, required=True,
                    help="Low frequency cutoff for each IFO.")
# the two PSD time options are only meaningful together; the script raises a
# ValueError later if just one of them is given
parser.add_argument("--psd-start-time", type=float, default=None,
                    help="Start time to use for PSD estimation if different "
                         "from analysis.")
parser.add_argument("--psd-end-time", type=float, default=None,
                    help="End time to use for PSD estimation if different "
                         "from analysis.")
# --gate and --psd-gate take strings that are split on ':' by
# gates_from_cli and psd_gates_from_cli respectively
parser.add_argument("--gate", nargs="+", type=str,
                    metavar="IFO:CENTRALTIME:HALFDUR:TAPERDUR",
                    help="Apply one or more gates to the data before "
                         "filtering.")
parser.add_argument("--gate-overwhitened", action="store_true", default=False,
                    help="Overwhiten data first, then apply the gates "
                         "specified in --gate. Overwhitening allows for "
                         "sharper tapers to be used, since lines are not "
                         "blurred.")
parser.add_argument("--psd-gate", nargs="+", type=str,
                    metavar="IFO:CENTRALTIME:HALFDUR:TAPERDUR",
                    help="Apply one or more gates to the data used for "
                         "computing the PSD. Gates are applied prior to "
                         "FFT-ing the data for PSD estimation.")

# add inference options
# the allowed evaluator names are the keys of inference.likelihood_evaluators;
# the chosen key is used later to look up the class to instantiate
parser.add_argument("--likelihood-evaluator", required=True,
                    choices=inference.likelihood_evaluators.keys(),
                    help="Evaluator class to use to calculate the likelihood.")
# the seed is passed to numpy.random.seed after parsing
parser.add_argument("--seed", type=int, default=0,
                    help="Seed to use for the random number generator that "
                         "initially distributes the walkers. Default is 0.")

# add sampler options
# NOTE(review): presumably this group defines --niterations and
# --skip-burn-in, which are read later but not declared in this file --
# confirm against pycbc.inference.option_utils
option_utils.add_sampler_option_group(parser)

# add config options
parser.add_argument("--config-files", type=str, nargs="+", required=True,
                    help="A file parsable by "
                         "pycbc.workflow.WorkflowConfigParser.")
# each override is split on ':' before being given to WorkflowConfigParser
parser.add_argument("--config-overrides", type=str, nargs="+", default=None,
                    metavar="SECTION:OPTION:VALUE",
                    help="List of section:option:value combinations to add "
                         "into the configuration file.")

# output options
parser.add_argument("--output-file", type=str, required=True,
                    help="Output file path.")
parser.add_argument("--force", action="store_true", default=False,
                    help="If the output-file already exists, overwrite it. "
                         "Otherwise, an OSError is raised.")
parser.add_argument("--save-strain", action="store_true", default=False,
                    help="Save the conditioned strain time series to the "
                         "output file. If gate-overwhitened, this is done "
                         "before all gates have been applied.")
parser.add_argument("--save-stilde", action="store_true", default=False,
                    help="Save the conditioned strain frequency series to "
                         "the output file. This is done after all gates have "
                         "been applied.")
parser.add_argument("--save-psd", action="store_true", default=False,
                    help="Save the psd of each ifo to the output file.")
# when no interval is given the sampler is run in a single call
parser.add_argument("--checkpoint-interval", type=int, default=None,
                    help="Number of iterations to take before saving new "
                         "samples to file.")
parser.add_argument("--checkpoint-fast", action="store_true",
                    help="Do not calculate derived data (eg. ACL or evidence) "
                         "after each checkpoint. Calculate them at the end.")
 
# verbose option
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Print logging messages.")

# add module pre-defined options
# each helper receives the parser and adds its own module's options to it
fft.insert_fft_option_group(parser)
pycbc.opt.insert_optimization_option_group(parser)
psd.insert_psd_option_group_multi_ifo(parser)
scheme.insert_processing_option_group(parser)
strain.insert_strain_option_group_multi_ifo(parser)
pycbc.weave.insert_weave_option_group(parser)

# parse command line
opts = parser.parse_args()

# verify options are sane
# NOTE(review): the PSD and strain verification calls in this group are
# commented out; the reason is not recorded here -- confirm before
# re-enabling them
fft.verify_fft_options(opts, parser)
pycbc.opt.verify_optimization_options(opts, parser)
#psd.verify_psd_options(opts, parser)
scheme.verify_processing_options(opts, parser)
#strain.verify_strain_options(opts, parser)
pycbc.weave.verify_weave_options(opts, parser)

# check for the output file
# refuse to clobber an existing file unless --force was given
if os.path.exists(opts.output_file) and not opts.force:
    raise OSError("output-file already exists; use --force if you wish to "
                  "overwrite it.")
# create the file for adding to
# it is opened in "w" mode and closed right away; every later write in this
# script reopens it in "a" (append) mode
fp = InferenceFile(opts.output_file, "w")
fp.close()

# setup log
pycbc.init_logging(opts.verbose)

# set the seed
# only numpy's global random state is seeded here
numpy.random.seed(opts.seed)
logging.info("Using seed %i" %(opts.seed))

# change measure level for FFT to 0
fft.fftw.set_measure_level(0)

# get scheme
# ctx is used further down as the context manager that wraps the analysis
ctx = scheme.from_cli(opts)
fft.from_cli(opts)

# get gates to apply
# both are dictionaries keyed by ifo (empty if the option was not given)
gates = gates_from_cli(opts)
psd_gates = psd_gates_from_cli(opts)

# the PSD time options must be given together; compare against None rather
# than relying on truthiness so that a time of 0 still counts as "given",
# and check before reading any data so that a bad combination fails early
use_psd_times = opts.psd_start_time is not None
if use_psd_times != (opts.psd_end_time is not None):
    raise ValueError("Must give --psd-start-time and --psd-end-time")

# get strain time series
strain_dict = strain.from_cli_multi_ifos(opts, opts.instruments,
                                         precision="double")
# apply gates if not waiting to overwhiten
if not opts.gate_overwhitened:
    logging.info("Applying gates to strain data")
    strain_dict = apply_gates_to_td(strain_dict, gates)

# get strain time series to use for PSD estimation
# if user has not given the PSD time options then use same data as analysis
if use_psd_times:
    logging.info("Will generate a different time series for PSD estimation")
    # NOTE: psd_opts is the same object as opts (not a copy), so the
    # assignments below also change opts.gps_start_time and
    # opts.gps_end_time for the remainder of the script
    psd_opts = opts
    psd_opts.gps_start_time = psd_opts.psd_start_time
    psd_opts.gps_end_time = psd_opts.psd_end_time
    psd_strain_dict = strain.from_cli_multi_ifos(psd_opts, opts.instruments,
                                                precision="double")
    # apply any gates
    logging.info("Applying gates to PSD data")
    psd_strain_dict = apply_gates_to_td(psd_strain_dict, psd_gates)
else:
    # same dictionary object as the analysis strain, including any --gate
    # gates applied above; --psd-gate is not applied on this path
    psd_strain_dict = strain_dict

# Everything below runs inside the processing scheme obtained from the
# command line: condition the data, set up the prior, likelihood and
# sampler from the configuration file, then run the sampler and write
# the results to the output file.
with ctx:

    # FFT strain and save each of the length of the FFT, delta_f, and
    # low frequency cutoff to a dict
    logging.info("FFT strain")
    stilde_dict = {}
    length_dict = {}
    delta_f_dict = {}
    low_frequency_cutoff_dict = {}
    for ifo in opts.instruments:
        stilde_dict[ifo] = strain_dict[ifo].to_frequencyseries()
        length_dict[ifo] = len(stilde_dict[ifo])
        delta_f_dict[ifo] = stilde_dict[ifo].delta_f
        # the same command line cutoff is used for every ifo
        low_frequency_cutoff_dict[ifo] = opts.low_frequency_cutoff

    # get PSD as frequency series
    psd_dict = psd.from_cli_multi_ifos(opts, length_dict, delta_f_dict,
                               low_frequency_cutoff_dict, opts.instruments,
                               strain_dict=psd_strain_dict, precision="double")

    # apply any gates to overwhitened data, if desired
    if opts.gate_overwhitened and opts.gate is not None:
        logging.info("Applying gates to overwhitened data")
        # overwhiten the data
        for ifo in gates:
            stilde_dict[ifo] /= psd_dict[ifo]
        stilde_dict = apply_gates_to_fd(stilde_dict, gates)
        # unwhiten the data for the likelihood generator
        for ifo in gates:
            stilde_dict[ifo] *= psd_dict[ifo]

    # save PSD
    if opts.save_psd:
        # apply dynamic range factor for saving PSDs since
        # plotting code expects it
        logging.info("Saving PSDs")
        psd_dyn_dict = {}
        for ifo, ifo_psd in psd_dict.items():
            psd_dyn_dict[ifo] = types.FrequencySeries(
                                    ifo_psd * pycbc.DYN_RANGE_FAC**2,
                                    delta_f=ifo_psd.delta_f)
        with InferenceFile(opts.output_file, "a") as fp:
            fp.write_psd(psds=psd_dyn_dict,
                         low_frequency_cutoff=low_frequency_cutoff_dict)

    # save stilde
    if opts.save_stilde:
        with InferenceFile(opts.output_file, "a") as fp:
            fp.write_stilde(stilde_dict)

    # save strain if desired
    if opts.save_strain:
        with InferenceFile(opts.output_file, "a") as fp:
            fp.write_strain(strain_dict)

    # read configuration file
    logging.info("Reading configuration file")
    if opts.config_overrides is not None:
        overrides = [override.split(":") for override in opts.config_overrides]
    else:
        overrides = None
    cp = WorkflowConfigParser(opts.config_files, overrides)

    # sanity check that each parameter in [variable_args] has a priors section
    # a subsection tag may name several parameters joined by "+"
    variable_args = cp.options("variable_args")
    prior_tags = set()
    for tag in cp.get_subsections("prior"):
        prior_tags.update(tag.split("+"))
    missing_priors = [param for param in variable_args
                      if param not in prior_tags]
    if missing_priors:
        raise KeyError("You are missing a priors section in the config file "
                       "for: %s" % ", ".join(missing_priors))

    # get parameters that do not change in sampler
    static_args = dict([(key, cp.get_opt_tags("static_args", key, []))
                        for key in cp.options("static_args")])
    # convert the values read from the file: numbers become floats, strings
    # of the form "[a, b]" become lists of strings, anything else is left
    # as it was read
    for key, val in list(static_args.items()):
        try:
            static_args[key] = float(val)
            continue
        except (TypeError, ValueError):
            pass
        # only bracketed strings are list strings; checking first avoids
        # using an exception from the converter as control flow
        if val.startswith('[') and val.endswith(']'):
            static_args[key] = convert_liststring_to_list(val)

    # get prior distribution for each variable parameter
    logging.info("Setting up priors for each parameter")
    distributions = inference.read_distributions_from_config(cp, "prior")

    # construct class that will return the prior
    prior = inference.PriorEvaluator(variable_args, *distributions)

    # select generator that will generate waveform
    # for likelihood evaluator
    logging.info("Setting up sampler")
    generator_function = select_waveform_generator(static_args["approximant"])

    # construct class that will generate waveforms
    # the epoch and delta_f are taken from the first entry of the respective
    # dictionaries; list() keeps the indexing valid on Python 3 as well
    generator = waveform.FDomainDetFrameGenerator(generator_function,
                        epoch=list(stilde_dict.values())[0].epoch,
                        variable_args=variable_args,
                        detectors=opts.instruments,
                        delta_f=list(delta_f_dict.values())[0], **static_args)

    # construct class that will return the natural logarithm of likelihood
    likelihood = inference.likelihood_evaluators[opts.likelihood_evaluator](
                        generator, stilde_dict, opts.low_frequency_cutoff,
                        psds=psd_dict, prior=prior)

    # create sampler that will run
    sampler = option_utils.sampler_from_cli(opts, likelihood)

    # get distribution to draw from for each walker initial position
    logging.info("Setting walkers initial conditions for varying parameters")

    # check if user wants to use different distributions than prior for setting
    # walkers initial positions
    if len(cp.get_subsections("initial")):
        initial_distributions = inference.read_distributions_from_config(cp,
            section="initial")
    else:
        initial_distributions = inference.read_distributions_from_config(cp,
            section="prior")

    # set the initial positions
    sampler.set_p0(initial_distributions)

    # setup checkpointing: intervals holds the iteration counts at which
    # the samples are written out, starting at 0 and ending at niterations
    if opts.checkpoint_interval:
        # every multiple of the checkpoint interval below niterations, then
        # niterations itself (this covers any shorter final stretch)
        intervals = list(range(0, opts.niterations, opts.checkpoint_interval))
        intervals.append(opts.niterations)

    # if not checkpointing then set intervals to run sampler in one call
    else:
        intervals = [0, opts.niterations]

    intervals = numpy.array(intervals)

    # check if user wants to skip the burn in
    if not opts.skip_burn_in:
        logging.info("Burn in")
        sampler.burn_in()
        n_burnin = sampler.burn_in_iterations
        logging.info("Used %i burn in samples" % n_burnin)
        with InferenceFile(opts.output_file, "a") as fp:
            # write the burn in results
            sampler.write_results(fp, max_iterations=opts.niterations+n_burnin)
    else:
        n_burnin = 0


    # increase the intervals to account for the burn in
    intervals += n_burnin

    # loop over number of checkpoints
    last_interval = len(intervals) - 2
    for i,start in enumerate(intervals[:-1]):

        end = intervals[i+1]

        # run the sampler for the number of iterations in this interval
        # NOTE(review): successive calls are presumably expected to continue
        # from where the previous call stopped -- confirm against the sampler
        logging.info("Running sampler for %i to %i out of %i iterations"%(
            start-n_burnin,end-n_burnin,opts.niterations))
        sampler.run(end-start)

        # write new samples
        with InferenceFile(opts.output_file, "a") as fp:

            logging.info("Writing results to file")
            sampler.write_results(fp, max_iterations=opts.niterations+n_burnin)

            # with --checkpoint-fast the derived data is only computed after
            # the final interval; the interval index is used for this test
            # because `end` has been shifted by the burn in iterations and so
            # cannot be compared to opts.niterations directly
            if not opts.checkpoint_fast or i == last_interval:

                # compute the acls and write
                logging.info("Computing acls")
                sampler.write_acls(fp, sampler.compute_acls(fp))

                # compute evidence, if supported
                try:
                    lnz, dlnz = sampler.calculate_logevidence(fp)
                    logging.info("Saving evidence")
                    sampler.write_logevidence(fp, lnz, dlnz)
                except NotImplementedError:
                    pass

# exit
logging.info("Done")
