#! /usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import logging
import numpy
import pycbc.opt
import pycbc.weave
import random
from pycbc import DYN_RANGE_FAC, fft, inference, psd, scheme, strain, waveform
from pycbc.io.inference_hdf import InferenceFile
from pycbc.types import FrequencySeries, MultiDetOptionAction
from pycbc.workflow import WorkflowConfigParser
from pycbc.inference import option_utils

def select_waveform_generator(approximant):
    """Return the waveform generator class for the given approximant.

    Parameters
    ----------
    approximant : str
        Name of the approximant to look up.

    Returns
    -------
    class
        One of the generator classes of ``pycbc.waveform``. The class itself
        is returned, not an instance.

    Raises
    ------
    ValueError
        If the approximant is a time domain ringdown, a frequency domain
        ringdown that has no matching generator here, or is not found in any
        of the approximant collections that are checked.
    """
    if approximant in waveform.fd_approximants():
        return waveform.FDomainCBCGenerator
    if approximant in waveform.td_approximants():
        return waveform.TDomainCBCGenerator
    if approximant in waveform.ringdown_fd_approximants:
        if approximant == 'FdQNM':
            return waveform.FDomainRingdownGenerator
        if approximant == 'FdQNMmultiModes':
            return waveform.FDomainMultiModeRingdownGenerator
        # fail loudly instead of silently returning None for a frequency
        # domain ringdown approximant without a generator mapped here
        raise ValueError("No generator is available for the frequency "
                         "domain ringdown approximant %s." % approximant)
    if approximant in waveform.ringdown_td_approximants:
        raise ValueError("Time domain ringdowns not supported")
    raise ValueError("%s is not a valid approximant."%approximant)

def convert_liststring_to_list(lstring):
    """Convert a string formatted as a list into a list of strings.

    The input must start with ``[`` and end with ``]``. The text in between
    is split on commas, and each element has surrounding whitespace and
    single quotes removed.

    Parameters
    ----------
    lstring : str
        String to convert, eg. ``"['H1', 'L1']"``.

    Returns
    -------
    list of str
        The elements of the list.

    Raises
    ------
    ValueError
        If ``lstring`` is not enclosed in square brackets.
    """
    if not (lstring.startswith('[') and lstring.endswith(']')):
        raise ValueError("%r is not a string of a list" % (lstring,))
    # split once and clean up each element
    return [str(item.strip().strip("'")) for item in lstring[1:-1].split(',')]

# build the command line interface
parser = argparse.ArgumentParser(
    usage="pycbc_inference [--options]",
    description="Runs an sampler to find the posterior.")

# options describing the data to analyze
parser.add_argument(
    "--instruments", type=str, nargs="+", required=True,
    help="IFOs, eg. H1 L1.")
parser.add_argument(
    "--frame-type", type=str, nargs="+",
    action=MultiDetOptionAction, metavar="IFO:FRAME_TYPE",
    help="Frame type for each IFO.")
parser.add_argument(
    "--low-frequency-cutoff", type=float, required=True,
    help="Low frequency cutoff for each IFO.")
parser.add_argument(
    "--psd-start-time", type=float, default=None,
    help="Start time to use for PSD estimation if different from analysis.")
parser.add_argument(
    "--psd-end-time", type=float, default=None,
    help="End time to use for PSD estimation if different from analysis.")

# options controlling the inference itself
parser.add_argument(
    "--likelihood-evaluator", required=True,
    choices=inference.likelihood_evaluators.keys(),
    help="Evaluator class to use to calculate the likelihood.")
parser.add_argument(
    "--seed", type=int, default=0,
    help="Seed to use for the random number generator that initially "
         "distributes the walkers. Default is 0.")

# options for the sampler are supplied by the inference option utilities
option_utils.add_sampler_option_group(parser)

# configuration file options
parser.add_argument(
    "--config-files", type=str, nargs="+", required=True,
    help="A file parsable by pycbc.workflow.WorkflowConfigParser.")

# options describing where and how often results are written
parser.add_argument(
    "--output-file", type=str, required=True,
    help="Output file path.")
parser.add_argument(
    "--checkpoint-interval", type=int, default=None,
    help="Number of iterations to take before saving new samples to file.")

# switch that raises the logging level to debug
parser.add_argument(
    "--verbose", action="store_true", default=False,
    help="")

# option groups that are pre-defined by other pycbc modules
fft.insert_fft_option_group(parser)
pycbc.opt.insert_optimization_option_group(parser)
psd.insert_psd_option_group_multi_ifo(parser)
scheme.insert_processing_option_group(parser)
strain.insert_strain_option_group_multi_ifo(parser)
pycbc.weave.insert_weave_option_group(parser)

# read the options given on the command line
opts = parser.parse_args()

# let each module check the options that it owns
# NOTE(review): the PSD and strain checks were already disabled in this
# script; the reason is not recorded here
fft.verify_fft_options(opts, parser)
pycbc.opt.verify_optimization_options(opts, parser)
#psd.verify_psd_options(opts, parser)
scheme.verify_processing_options(opts, parser)
#strain.verify_strain_options(opts, parser)
pycbc.weave.verify_weave_options(opts, parser)

# configure logging: debug messages with --verbose, warnings otherwise
log_level = logging.DEBUG if opts.verbose else logging.WARN
logging.basicConfig(format="%(asctime)s : %(message)s", level=log_level)

# seed the numpy random number generator
numpy.random.seed(opts.seed)
logging.info("Using seed %i", opts.seed)

# use measure level 0 for the FFTW backend
fft.fftw.set_measure_level(0)

# create the processing scheme context and apply the FFT options
ctx = scheme.from_cli(opts)
fft.from_cli(opts)

# get the strain time series to analyze for each instrument
strain_dict = strain.from_cli_multi_ifos(opts, opts.instruments,
                                         precision="double")

# get the strain time series to use for PSD estimation
# if the user has not given the PSD time options then the same data as the
# analysis is used; compare against None so that a time of 0 still counts
# as given
have_psd_start = opts.psd_start_time is not None
have_psd_end = opts.psd_end_time is not None
if have_psd_start and have_psd_end:
    logging.info("Will generate a different time series for PSD estimation")
    # override the GPS times on a shallow copy of the options so that the
    # options used by the rest of the script are left untouched
    psd_opts = argparse.Namespace(**vars(opts))
    psd_opts.gps_start_time = psd_opts.psd_start_time
    psd_opts.gps_end_time = psd_opts.psd_end_time
    psd_strain_dict = strain.from_cli_multi_ifos(psd_opts, opts.instruments,
                                                precision="double")
elif have_psd_start or have_psd_end:
    # only one of the two PSD time options was given
    raise ValueError("Must give --psd-start-time and --psd-end-time")
else:
    psd_strain_dict = strain_dict

with ctx:

    # FFT the strain of each instrument and save the length of the FFT,
    # its delta_f, and the low frequency cutoff to dicts keyed by instrument
    logging.info("FFT strain")
    stilde_dict = {}
    length_dict = {}
    delta_f_dict = {}
    low_frequency_cutoff_dict = {}
    for ifo in opts.instruments:
        stilde_dict[ifo] = strain_dict[ifo].to_frequencyseries()
        length_dict[ifo] = len(stilde_dict[ifo])
        delta_f_dict[ifo] = stilde_dict[ifo].delta_f
        low_frequency_cutoff_dict[ifo] = opts.low_frequency_cutoff

    # get PSD as frequency series
    psd_dict = psd.from_cli_multi_ifos(opts, length_dict, delta_f_dict,
                               low_frequency_cutoff_dict, opts.instruments,
                               strain_dict=psd_strain_dict, precision="double")

    # apply dynamic range factor for saving PSDs since plotting code expects it
    logging.info("Saving PSDs")
    psd_dyn_dict = {}
    for ifo, ifo_psd in psd_dict.items():
        psd_dyn_dict[ifo] = FrequencySeries(ifo_psd*DYN_RANGE_FAC**2,
                                            delta_f=ifo_psd.delta_f)

    # save PSD; opening with "w" creates the output file that the later
    # writes append to
    with InferenceFile(opts.output_file, "w") as fp:
        fp.write_psds(psds=psd_dyn_dict,
                      low_frequency_cutoff=low_frequency_cutoff_dict)

    # read configuration file
    logging.info("Reading configuration file")
    cp = WorkflowConfigParser(opts.config_files)

    # sanity check that each parameter in [variable_args] has a priors section
    # a prior subsection may name several parameters joined with "+"
    variable_args = cp.options("variable_args")
    subsections = cp.get_subsections("prior")
    tags = [tag for subsection in subsections
            for tag in subsection.split("+")]
    if not all(param in tags for param in variable_args):
        raise KeyError("You are missing a priors section in the config file.")

    # get parameters that do not change in sampler
    static_args = {key: cp.get_opt_tags("static_args", key, [])
                   for key in cp.options("static_args")}

    # convert each static argument to a float if possible, otherwise to a
    # list if it is formatted as one, otherwise leave it as read
    for key, val in list(static_args.items()):
        try:
            static_args[key] = float(val)
            continue
        except (TypeError, ValueError):
            pass
        try:
            static_args[key] = convert_liststring_to_list(val)
        except Exception:
            # not formatted as a list so keep the value unchanged
            pass

    # get prior distribution for each variable parameter
    logging.info("Setting up priors for each parameter")
    distributions = inference.read_distributions_from_config(cp, "prior")

    # construct class that will return the prior
    prior = inference.PriorEvaluator(variable_args, *distributions)

    # select generator that will generate waveform
    # for likelihood evaluator
    logging.info("Setting up sampler")
    generator_function = select_waveform_generator(static_args["approximant"])

    # construct class that will generate waveforms; the epoch and delta_f
    # are taken from the first entry of their dicts
    generator = waveform.FDomainDetFrameGenerator(generator_function,
                        epoch=list(stilde_dict.values())[0].epoch,
                        variable_args=variable_args,
                        detectors=opts.instruments,
                        delta_f=list(delta_f_dict.values())[0], **static_args)

    # construct class that will return the natural logarithm of likelihood
    likelihood = inference.likelihood_evaluators[opts.likelihood_evaluator](
                        generator, stilde_dict, opts.low_frequency_cutoff,
                        psds=psd_dict, prior=prior)

    # create sampler that will run
    sampler = option_utils.sampler_from_cli(opts, likelihood)

    # get distribution to draw from for each walker initial position
    logging.info("Setting walkers initial conditions for varying parameters")

    # check if user wants to use different distributions than prior for setting
    # walkers initial positions
    if len(cp.get_subsections("initial")):
        initial_distributions = inference.read_distributions_from_config(cp,
            section="initial")
    else:
        initial_distributions = inference.read_distributions_from_config(cp,
            section="prior")

    # set the initial positions
    sampler.set_p0(initial_distributions)

    # setup checkpointing
    if opts.checkpoint_interval:

        # determine intervals to run sampler until we save the new samples;
        # ending the list with niterations already covers a final interval
        # that is shorter than the checkpoint interval
        intervals = list(range(0, opts.niterations,
                               opts.checkpoint_interval)) \
                    + [opts.niterations]

    # if not checkpointing then set intervals to run sampler in one call
    else:
        intervals = [0, opts.niterations]

    intervals = numpy.array(intervals)

    # check if user wants to skip the burn in
    if not opts.skip_burn_in:
        logging.info("Burn in")
        sampler.burn_in()
        n_burnin = sampler.burn_in_iterations
        logging.info("Used %i burn in samples" % n_burnin)
        with InferenceFile(opts.output_file, "a") as fp:
            # write the burn in results
            sampler.write_results(fp, max_iterations=opts.niterations+n_burnin)
    else:
        n_burnin = 0

    # increase the intervals to account for the burn in
    intervals += n_burnin

    # loop over number of checkpoints
    for start, end in zip(intervals[:-1], intervals[1:]):

        # run the sampler for the number of iterations in this interval
        logging.info("Running sampler for %i to %i out of %i iterations"%(
            start-n_burnin,end-n_burnin,opts.niterations))
        sampler.run(end-start)

        # write new samples
        with InferenceFile(opts.output_file, "a") as fp:
            logging.info("Writing results to file")
            sampler.write_results(fp, max_iterations=opts.niterations+n_burnin)
            # compute the acls and write
            logging.info("Computing acls")
            sampler.write_acls(fp, sampler.compute_acls(fp))
            # compute evidence, if supported
            try:
                lnz, dlnz = sampler.calculate_logevidence(fp)
                logging.info("Saving evidence")
                sampler.write_logevidence(fp, lnz, dlnz)
            except NotImplementedError:
                pass

# exit
logging.info("Done")
