#!/home/travis/build/pycbc-build/environment/bin/python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Runs a sampler to find the posterior distributions.
"""

import os
import argparse
import logging
import numpy
import pycbc
import pycbc.opt
import pycbc.weave
from pycbc import distributions
from pycbc import transforms
from pycbc import fft
from pycbc import gate
from pycbc import inference
from pycbc import psd
from pycbc import scheme
from pycbc import strain
from pycbc import types
from pycbc.io.inference_hdf import InferenceFile
from pycbc.waveform import generator
from pycbc.inference import option_utils
from pycbc.inference import prior

# command line usage
parser = argparse.ArgumentParser(usage=__file__ + " [--options]",
                                 description=__doc__)

# add data options
parser.add_argument("--instruments", type=str, nargs="+", required=True,
                    help="IFOs, eg. H1 L1.")
option_utils.add_low_frequency_cutoff_opt(parser)
parser.add_argument("--psd-start-time", type=float, default=None,
                    help="Start time to use for PSD estimation if different "
                         "from analysis.")
parser.add_argument("--psd-end-time", type=float, default=None,
                    help="End time to use for PSD estimation if different "
                         "from analysis.")

# add inference options
parser.add_argument("--likelihood-evaluator", required=True,
                    choices=inference.likelihood_evaluators.keys(),
                    help="Evaluator class to use to calculate the likelihood.")
parser.add_argument("--seed", type=int, default=0,
                    help="Seed to use for the random number generator that "
                         "initially distributes the walkers. Default is 0.")
parser.add_argument("--samples-file", default=None,
                    help="Use an iteration from an InferenceFile as the "
                         "initial proposal distribution. The same "
                         "number of walkers and the same [variable_args] "
                         "section in the configuration file should be used. "
                         "The priors must allow encompass the initial "
                         "positions from the InferenceFile being read.")

# add sampler options
option_utils.add_sampler_option_group(parser)

# add config options
option_utils.add_config_opts_to_parser(parser)

# output options
parser.add_argument("--output-file", type=str, required=True,
                    help="Output file path.")
parser.add_argument("--force", action="store_true", default=False,
                    help="If the output-file already exists, overwrite it. "
                         "Otherwise, an OSError is raised.")
parser.add_argument("--save-strain", action="store_true", default=False,
                    help="Save the conditioned strain time series to the "
                         "output file. If gate-overwhitened, this is done "
                         "before all gates have been applied.")
parser.add_argument("--save-stilde", action="store_true", default=False,
                    help="Save the conditioned strain frequency series to "
                         "the output file. This is done after all gates have "
                         "been applied.")
parser.add_argument("--save-psd", action="store_true", default=False,
                    help="Save the psd of each ifo to the output file.")
parser.add_argument("--checkpoint-interval", type=int, default=None,
                    help="Number of iterations to take before saving new "
                         "samples to file.")
parser.add_argument("--checkpoint-fast", action="store_true",
                    help="Do not calculate derived data (eg. ACL or evidence) "
                         "after each checkpoint. Calculate them at the end.")

# verbose option
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Print logging messages.")

# add module pre-defined options
fft.insert_fft_option_group(parser)
pycbc.opt.insert_optimization_option_group(parser)
psd.insert_psd_option_group_multi_ifo(parser)
scheme.insert_processing_option_group(parser)
strain.insert_strain_option_group_multi_ifo(parser)
pycbc.weave.insert_weave_option_group(parser)
gate.add_gate_option_group(parser)

# parse command line
opts = parser.parse_args()

# verify options are sane
fft.verify_fft_options(opts, parser)
pycbc.opt.verify_optimization_options(opts, parser)
#psd.verify_psd_options(opts, parser)
scheme.verify_processing_options(opts, parser)
#strain.verify_strain_options(opts, parser)
pycbc.weave.verify_weave_options(opts, parser)

# check for the output file
if os.path.exists(opts.output_file) and not opts.force:
    raise OSError("output-file already exists; use --force if you wish to "
                  "overwrite it.")

# setup log
pycbc.init_logging(opts.verbose)

# set seed
numpy.random.seed(opts.seed)
logging.info("Using seed %i", opts.seed)

# we'll silence numpy warnings since they are benign and make for confusing
# logging output
numpy.seterr(divide='ignore', invalid='ignore')

# get scheme
ctx = scheme.from_cli(opts)
fft.from_cli(opts)

# get the data and psd
logging.info("Loading data")
strain_dict, stilde_dict, psd_dict = option_utils.data_from_cli(opts)
low_frequency_cutoff_dict = option_utils.low_frequency_cutoff_from_cli(opts)

with ctx:

    # read configuration file
    cp = option_utils.config_parser_from_cli(opts)

    # get the vairable and static arguments from the config file
    variable_args, static_args, constraints = \
                                         option_utils.read_args_from_config(cp)

    # get prior distribution for each variable parameter
    logging.info("Setting up priors for each parameter")
    dists = distributions.read_distributions_from_config(cp, "prior")

    # construct class that will return the prior
    prior_eval = prior.PriorEvaluator(variable_args, *dists,
                                      **{"constraints" : constraints})

    # get sampling transformations
    if cp.has_section('sampling_parameters'):
        sampling_parameters, replace_parameters = \
            option_utils.read_sampling_args_from_config(cp)
        sampling_transforms = transforms.read_transforms_from_config(cp,
            'sampling_transforms')
        logging.info("Sampling in {} in place of {}".format(
            ', '.join(sampling_parameters), ', '.join(replace_parameters)))
    else:
        sampling_parameters = replace_parameters = sampling_transforms = None

    # select generator that will generate waveform for a single IFO
    # for likelihood evaluator
    logging.info("Setting up sampler")
    generator_function = generator.select_waveform_generator(
                                                    static_args["approximant"])

    # construct class that will generate waveforms
    generated_waveform = generator.FDomainDetFrameGenerator(
                       generator_function, epoch=stilde_dict.values()[0].epoch,
                       variable_args=variable_args, detectors=opts.instruments,
                       delta_f=stilde_dict.values()[0].delta_f,
                       delta_t=strain_dict.values()[0].delta_t,
                       **static_args)

    # construct class that will return the natural logarithm of likelihood
    likelihood = inference.likelihood_evaluators[opts.likelihood_evaluator](
                        generated_waveform, stilde_dict,
                        low_frequency_cutoff_dict.values()[0],
                        psds=psd_dict, prior=prior_eval,
                        sampling_parameters=sampling_parameters,
                        replace_parameters=replace_parameters,
                        sampling_transforms=sampling_transforms)

    # create sampler that will run
    sampler = option_utils.sampler_from_cli(opts, likelihood)

    # save command line and data, if desired
    with InferenceFile(opts.output_file, "w") as fp:
        logging.info("Writing command line and data to output file")
        fp.write_command_line()
        fp.write_data(strain_dict=strain_dict if opts.save_strain else None,
                      stilde_dict=stilde_dict if opts.save_stilde else None,
                      psd_dict=psd_dict if opts.save_psd else None,
                      low_frequency_cutoff_dict=low_frequency_cutoff_dict)

    # set the walkers initial positions from a pre-existing InferenceFile
    # or a specific initial distribution listed in the configuration file
    # or else use the prior distributions to set initial positions
    logging.info("Setting walkers initial conditions for varying parameters")
    if opts.samples_file:
        with InferenceFile(opts.samples_file, "r") as fp:
            iteration = fp.niterations - 1
            logging.info("Initial positions taken from iteration %i in %s",
                         iteration, opts.samples_file)
            p0 = fp.read_samples(sampler.variable_args, iteration=iteration)
        init_prior = None
    elif len(cp.get_subsections("initial")):
        initial_dists = distributions.read_distributions_from_config(
                                                         cp, section="initial")
        init_prior = prior.PriorEvaluator(variable_args, *initial_dists,
                                          **{"constraints" : constraints})
        p0 = None
    else:
        init_prior = None
        p0 = None
    sampler.set_p0(samples=p0, prior=init_prior)

    # setup checkpointing
    if opts.checkpoint_interval:

        # determine intervals to run sampler until we save the new samples 
        intervals = [i for i in range(0, opts.niterations,
                                      opts.checkpoint_interval)]

        # determine if there is a small bit at the end
        remainder = opts.niterations % opts.checkpoint_interval
        if remainder:
            intervals += [intervals[-1] + remainder]
        else:
            intervals += [opts.niterations]

    # if not checkpointing then set intervals to run sampler in one call
    else:
        intervals = [0, opts.niterations]

    intervals = numpy.array(intervals)

    # check if user wants to burn in
    if not opts.skip_burn_in:
        logging.info("Burn in")
        sampler.burn_in()
        n_burnin = sampler.burn_in_iterations
        logging.info("Used %i burn in samples", n_burnin)
        with InferenceFile(opts.output_file, "a") as fp:
            # write the burn in results
            sampler.write_results(fp,
                                  max_iterations=opts.niterations + n_burnin)
    else:
        n_burnin = 0


    # increase the intervals to account for the burn in
    intervals += n_burnin

    # if getting samples from file then put sampler and random number generator
    # back in its former state
    if opts.samples_file:
        with InferenceFile(opts.samples_file, "r") as fp:
            sampler.set_state_from_file(fp)
            numpy.random.set_state(fp.read_random_state())

    # loop over number of checkpoints
    for i, start in enumerate(intervals[:-1]):

        end = intervals[i + 1]

        # run sampler and set initial values to None so that sampler
        # picks up from where it left off next call
        logging.info("Running sampler for %i to %i out of %i iterations",
                     start - n_burnin, end - n_burnin, opts.niterations)
        sampler.run(end - start)

        # write new samples
        with InferenceFile(opts.output_file, "a") as fp:

            logging.info("Writing results to file")
            sampler.write_results(fp, start_iteration=start,
                                  end_iteration=end,
                                  max_iterations=opts.niterations + n_burnin)
            logging.info("Writing state to file")
            fp.write_random_state()

            # include a try-except statement because all samples might not
            # have this implemented
            try:
                sampler.write_state(fp)
            except NotImplementedError:
                logging.warn("Writing state is not implemented for %s",
                             sampler.name)

            if not opts.checkpoint_fast or end == opts.niterations:

                # compute the acls and write
                logging.info("Computing acls")
                sampler.write_acls(fp, sampler.compute_acls(fp))

                # compute evidence, if supported
                try:
                    lnz, dlnz = sampler.calculate_logevidence(fp)
                    logging.info("Saving evidence")
                    sampler.write_logevidence(fp, lnz, dlnz)
                except NotImplementedError:
                    pass

            # clear the in-memory chain to save memory
            logging.info("Clearing chain")
            sampler.clear_chain()

# exit
logging.info("Done")
