#!/home/travis/build/pycbc-build/environment/bin/python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Runs a sampler to find the posterior distributions.
"""

import os
import argparse
import logging
import numpy
import pycbc
import pycbc.opt
import pycbc.weave
from pycbc import distributions
from pycbc import transforms
from pycbc import fft
from pycbc import gate
from pycbc import inference
from pycbc import psd
from pycbc import scheme
from pycbc import strain
from pycbc import types
from pycbc.io.inference_hdf import InferenceFile
from pycbc.waveform import generator
from pycbc.inference import option_utils
from pycbc.inference import prior
from pycbc.inference import burn_in

# Build the command line interface. Options are registered in the order in
# which they should appear in the --help output.
parser = argparse.ArgumentParser(usage=__file__ + " [--options]",
                                 description=__doc__)

# add data options
parser.add_argument("--instruments", type=str, nargs="+", required=True,
                    help="IFOs, eg. H1 L1.")
option_utils.add_low_frequency_cutoff_opt(parser)
parser.add_argument("--psd-start-time", type=float, default=None,
                    help="Start time to use for PSD estimation if different "
                         "from analysis.")
parser.add_argument("--psd-end-time", type=float, default=None,
                    help="End time to use for PSD estimation if different "
                         "from analysis.")

# add inference options
# the allowed evaluator names are the keys of the registry exposed by
# pycbc.inference
parser.add_argument("--likelihood-evaluator", required=True,
                    choices=inference.likelihood_evaluators.keys(),
                    help="Evaluator class to use to calculate the likelihood.")
parser.add_argument("--seed", type=int, default=0,
                    help="Seed to use for the random number generator that "
                         "initially distributes the walkers. Default is 0.")
parser.add_argument("--samples-file", default=None,
                    help="Use an iteration from an InferenceFile as the "
                         "initial proposal distribution. The same "
                         "number of walkers and the same [variable_args] "
                         "section in the configuration file should be used. "
                         "The priors must encompass the initial "
                         "positions from the InferenceFile being read.")

# add sampler options
option_utils.add_sampler_option_group(parser)

# add config options
option_utils.add_config_opts_to_parser(parser)

# output options
parser.add_argument("--output-file", type=str, required=True,
                    help="Output file path.")
parser.add_argument("--force", action="store_true", default=False,
                    help="If the output-file already exists, overwrite it. "
                         "Otherwise, an OSError is raised.")
parser.add_argument("--save-strain", action="store_true", default=False,
                    help="Save the conditioned strain time series to the "
                         "output file. If gate-overwhitened, this is done "
                         "before all gates have been applied.")
parser.add_argument("--save-stilde", action="store_true", default=False,
                    help="Save the conditioned strain frequency series to "
                         "the output file. This is done after all gates have "
                         "been applied.")
parser.add_argument("--save-psd", action="store_true", default=False,
                    help="Save the psd of each ifo to the output file.")
parser.add_argument("--checkpoint-interval", type=int, default=None,
                    help="Number of iterations to take before saving new "
                         "samples to file, calculating ACL, and updating "
                         "burn-in estimate.")
parser.add_argument("--checkpoint-fast", action="store_true",
                    help="Do not calculate ACL after each checkpoint, only at "
                         "the end. Not applicable if n-independent-samples "
                         "have been specified.")

# verbose option
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Print logging messages.")

# add module pre-defined options; each helper registers its own option
# group on the parser
fft.insert_fft_option_group(parser)
pycbc.opt.insert_optimization_option_group(parser)
psd.insert_psd_option_group_multi_ifo(parser)
scheme.insert_processing_option_group(parser)
strain.insert_strain_option_group_multi_ifo(parser)
pycbc.weave.insert_weave_option_group(parser)
gate.add_gate_option_group(parser)

# parse command line
opts = parser.parse_args()

# run the option sanity checks of each module, in a fixed order; every
# verifier receives the parsed options and the parser
for verify in (fft.verify_fft_options,
               pycbc.opt.verify_optimization_options,
               scheme.verify_processing_options,
               pycbc.weave.verify_weave_options):
    verify(opts, parser)
# the psd and strain verifiers are left disabled:
#psd.verify_psd_options(opts, parser)
#strain.verify_strain_options(opts, parser)

# refuse to overwrite an existing output file unless --force was given
if not opts.force and os.path.exists(opts.output_file):
    raise OSError("output-file already exists; use --force if you wish to "
                  "overwrite it.")

# work out the stopping criterion: exactly one of niterations and
# n-independent-samples has to be given
use_niterations = opts.niterations is not None
use_nindependent = opts.n_independent_samples is not None
if use_niterations and use_nindependent:
    raise ValueError("Must specify either niterations or n-independent-"
                     "samples, not both")
if not use_niterations and not use_nindependent:
    raise ValueError("Must specify niterations or n-independent-samples; "
                     "see --help")
if use_nindependent and opts.checkpoint_interval is None:
    raise ValueError("n-independent-samples requires a checkpoint-"
                     "interval; see help")

# max_iterations stays None when running to a number of independent samples
max_iterations = opts.niterations
# get_nsamples is the target the sampling loop counts towards
if use_niterations:
    get_nsamples = opts.niterations
else:
    get_nsamples = opts.n_independent_samples

# configure logging according to --verbose
pycbc.init_logging(opts.verbose)

# tell the user when a deprecated option was passed
if opts.skip_burn_in:
    deprecation_msg = (
        "DEPRECATION WARNING: Option --skip-burn-in has been deprecated and "
        "does nothing. If you wish to use no burn in, simply do not provide "
        "a burn-in-function. This option will be removed in future versions.")
    logging.warning(deprecation_msg)

# seed numpy's global random number generator
seed = opts.seed
numpy.random.seed(seed)
logging.info("Using seed %i", seed)

# numpy's divide and invalid-value warnings are benign here and only clutter
# the log, so turn them off
numpy.seterr(divide='ignore', invalid='ignore')

# create the processing scheme context and apply the FFT options
ctx = scheme.from_cli(opts)
fft.from_cli(opts)

# load the strain, frequency-domain strain and PSDs, then the low frequency
# cutoffs
logging.info("Loading data")
loaded_data = option_utils.data_from_cli(opts)
strain_dict, stilde_dict, psd_dict = loaded_data
low_frequency_cutoff_dict = option_utils.low_frequency_cutoff_from_cli(opts)

# Everything below runs inside the processing scheme context: set up the
# priors, waveform generator, likelihood and sampler, then run the sampler in
# checkpointed intervals, writing to the output file after each interval.
with ctx:

    # read configuration file
    cp = option_utils.config_parser_from_cli(opts)

    # get the variable and static arguments from the config file
    variable_args, static_args, constraints = \
                                         option_utils.read_args_from_config(cp)

    # get prior distribution for each variable parameter
    logging.info("Setting up priors for each parameter")
    dists = distributions.read_distributions_from_config(cp, "prior")

    # construct class that will return the prior
    prior_eval = prior.PriorEvaluator(variable_args, *dists,
                                      **{"constraints" : constraints})

    # get sampling transformations
    if cp.has_section('sampling_parameters'):
        sampling_parameters, replace_parameters = \
            option_utils.read_sampling_args_from_config(cp)
        sampling_transforms = transforms.read_transforms_from_config(cp,
            'sampling_transforms')
        logging.info("Sampling in {} in place of {}".format(
            ', '.join(sampling_parameters), ', '.join(replace_parameters)))
    else:
        sampling_parameters = replace_parameters = sampling_transforms = None

    # select generator that will generate waveform for a single IFO
    # for likelihood evaluator
    logging.info("Setting up sampler")
    generator_function = generator.select_waveform_generator(
                                                    static_args["approximant"])

    # the epoch, delta_f and delta_t are read off the first entry of the data
    # dictionaries; list() keeps the indexing valid when values() returns a
    # view rather than a list
    first_stilde = list(stilde_dict.values())[0]
    first_strain = list(strain_dict.values())[0]

    # construct class that will generate waveforms
    generated_waveform = generator.FDomainDetFrameGenerator(
                       generator_function, epoch=first_stilde.epoch,
                       variable_args=variable_args, detectors=opts.instruments,
                       delta_f=first_stilde.delta_f,
                       delta_t=first_strain.delta_t,
                       **static_args)

    # construct class that will return the natural logarithm of likelihood;
    # the first low frequency cutoff in the dictionary is the one passed on
    likelihood = inference.likelihood_evaluators[opts.likelihood_evaluator](
                        generated_waveform, stilde_dict,
                        list(low_frequency_cutoff_dict.values())[0],
                        psds=psd_dict, prior=prior_eval,
                        sampling_parameters=sampling_parameters,
                        replace_parameters=replace_parameters,
                        sampling_transforms=sampling_transforms)

    # construct class that evaluates the requested burn in functions
    burn_in_eval = burn_in.BurnIn(opts.burn_in_function,
                                min_iterations=opts.min_burn_in)

    # create sampler that will run
    sampler = option_utils.sampler_from_cli(opts, likelihood)

    # save command line and data, if desired; opening with "w" creates the
    # output file anew
    with InferenceFile(opts.output_file, "w") as fp:
        logging.info("Writing command line and data to output file")
        fp.write_command_line()
        fp.write_data(strain_dict=strain_dict if opts.save_strain else None,
                      stilde_dict=stilde_dict if opts.save_stilde else None,
                      psd_dict=psd_dict if opts.save_psd else None,
                      low_frequency_cutoff_dict=low_frequency_cutoff_dict)

    # set the walkers initial positions from a pre-existing InferenceFile
    # or a specific initial distribution listed in the configuration file
    # or else use the prior distributions to set initial positions
    logging.info("Setting walkers initial conditions for varying parameters")
    if opts.samples_file:
        with InferenceFile(opts.samples_file, "r") as fp:
            # start from the last iteration stored in the file
            iteration = fp.niterations - 1
            logging.info("Initial positions taken from iteration %i in %s",
                         iteration, opts.samples_file)
            p0 = fp.read_samples(sampler.variable_args, iteration=iteration)
        init_prior = None
    elif len(cp.get_subsections("initial")):
        initial_dists = distributions.read_distributions_from_config(
                                                         cp, section="initial")
        init_prior = prior.PriorEvaluator(variable_args, *initial_dists,
                                          **{"constraints" : constraints})
        p0 = None
    else:
        init_prior = None
        p0 = None
    sampler.set_p0(samples=p0, prior=init_prior)

    # if getting samples from file then put sampler and random number generator
    # back in its former state
    if opts.samples_file:
        with InferenceFile(opts.samples_file, "r") as fp:
            try:
                sampler.set_state_from_file(fp)
                numpy.random.set_state(fp.read_random_state())
            except NotImplementedError:
                logging.warning("Sampler does not support setting a random "
                                "state. Just setting initial positions.")

    # run sampler's burn in if it is in the list of burn in functions
    n_sampler_burn_in = 0
    if "use_sampler" in burn_in_eval.burn_in_functions:
        # remove the sampler's burn in so we don't run more than once
        burn_in_eval.burn_in_functions.pop("use_sampler")
        with InferenceFile(opts.output_file, "a") as fp:
            logging.info("Running sampler's burn in function")
            burnidx = burn_in.use_sampler(sampler, fp)
            sampler.write_burn_in_iterations(fp, burnidx)
            # the requested iterations are run on top of the burn in, so
            # increase the maximum number of iterations by the largest burn
            # in index
            n_sampler_burn_in = burnidx.max()
            if max_iterations is not None:
                max_iterations += n_sampler_burn_in
            # write the burn in results
            logging.info("Writing burn in samples to file")
            sampler.write_results(fp, max_iterations=max_iterations)
            # not every sampler implements writing its state
            try:
                sampler.write_state(fp)
            except NotImplementedError:
                logging.warning("Writing state is not implemented for %s",
                                sampler.name)

    # run sampler until we have the desired number of samples
    nsamples = 0
    interval = opts.checkpoint_interval
    if interval is None:
        interval = get_nsamples
    # start keeps track of the number of iterations we have advanced (which
    # may not be the same as the number of samples, if we are using
    # n-independent-samples)
    start = n_sampler_burn_in
    # iteration index at which a fixed-length run is complete; start and end
    # are offset by the sampler's burn in, so the target has to be as well
    final_iteration = n_sampler_burn_in + get_nsamples

    while nsamples < get_nsamples:

        end = start + interval

        # shorten the last interval if we would go past the number of
        # iterations
        if opts.n_independent_samples is None and end > final_iteration:
            interval = final_iteration - start
            end = final_iteration

        # advance the sampler by the current interval
        logging.info("Running sampler for {} to {} iterations".format(start,
                                                                      end))
        sampler.run(interval)

        # write new samples
        with InferenceFile(opts.output_file, "a") as fp:

            logging.info("Writing results to file")
            sampler.write_results(fp, start_iteration=start,
                                  end_iteration=end,
                                  max_iterations=max_iterations)
            logging.info("Writing state to file")
            fp.write_random_state()

            logging.info("Updating burn in")
            burn_in_eval.update(sampler, fp)

            # include a try-except statement because all samplers might not
            # have this implemented
            try:
                sampler.write_state(fp)
            except NotImplementedError:
                logging.warning("Writing state is not implemented for %s",
                                sampler.name)

            # compute the acls and write; with --checkpoint-fast and a fixed
            # number of iterations this is only done on the last interval
            if opts.n_independent_samples is not None \
                    or end >= final_iteration or not opts.checkpoint_fast:
                logging.info("Computing acls")
                sampler.write_acls(fp, sampler.compute_acls(fp))

            # update nsamples for next loop
            if opts.n_independent_samples is not None:
                # read the count from the handle that is already open instead
                # of opening the output file a second time
                nsamples = fp.n_independent_samples
                logging.info("Have {} independent samples".format(nsamples))
            else:
                nsamples += interval

            # clear the in-memory chain to save memory
            logging.info("Clearing chain")
            sampler.clear_chain()

            start = end

    # compute evidence, if supported
    with InferenceFile(opts.output_file, 'a') as fp:
        try:
            lnz, dlnz = sampler.calculate_logevidence(fp)
            logging.info("Saving evidence")
            sampler.write_logevidence(fp, lnz, dlnz)
        except NotImplementedError:
            # samplers that cannot provide the evidence are skipped on purpose
            pass


# exit
logging.info("Done")
