#!/usr/bin/env python

# Copyright (C) 2016 Miriam Cabero Mueller, Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#

import argparse
import itertools
import logging
import sys

import numpy

from matplotlib import (patches, use)

import pycbc
import pycbc.version
from pycbc.results import metadata

from pycbc import __version__
from pycbc import conversions
from pycbc.workflow import WorkflowConfigParser
from pycbc.inference import (option_utils, io)
from pycbc import transforms

from pycbc.results.scatter_histograms import create_multidim_plot

# select the matplotlib "agg" backend for this script
use('agg')

# start from the results argument parser provided by pycbc.inference.io;
# the options below are registered in the order they are listed
parser = io.ResultsArgumentParser()

# options specific to this program
parser.add_argument(
    "--version",
    action="version",
    version=__version__,
    help="Prints version information.")
parser.add_argument(
    "--output-file",
    type=str,
    required=True,
    help="Output plot path.")
parser.add_argument(
    "--verbose",
    action="store_true",
    default=False,
    help="Be verbose")
parser.add_argument(
    "--plot-prior",
    nargs="+",
    type=str,
    help=("Plot the prior on the 1D marginal plots using the given config "
          "file(s)."))
parser.add_argument(
    "--prior-nsamples",
    type=int,
    default=10000,
    help=("The number of samples to use for plotting the prior. "
          "Default is 10000."))

# shared option groups: which plots to make, then scatter and density styling
option_utils.add_plot_posterior_option_group(parser)
option_utils.add_scatter_option_group(parser)
option_utils.add_density_option_group(parser)

# output resolution
parser.add_argument(
    '--dpi',
    type=int,
    default=200,
    help="Set the DPI of the plot. Default is 200.")

# read the command line
opts = parser.parse_args()

# configure logging according to --verbose
pycbc.init_logging(opts.verbose)

# read the open files, parameter names, labels and samples selected on the
# command line
fps, parameters, labels, samples = io.results_from_cli(opts)

# wrap non-list results in a list so the input files can be iterated over
if not isinstance(fps, list):
    fps = [fps]
if not isinstance(samples, list):
    samples = [samples]

# colorbar values and label; both stay None unless a z-arg was requested
zvals = None
zlbl = None
if opts.z_arg is not None:
    logging.info("Getting samples for colorbar")
    want_snr = opts.z_arg == 'snr'
    # 'snr' is not loaded directly: it is derived from 'loglikelihood'
    z_arg = 'loglikelihood' if want_snr else opts.z_arg
    zlbl = opts.z_arg_labels[opts.z_arg]
    # collect one set of z-values per input file
    zvals = []
    for fp in fps:
        zsamples = fp.samples_from_cli(opts, parameters=z_arg)
        if want_snr:
            # subtract lognl from the log likelihood and convert the
            # resulting log likelihood ratio to an SNR
            loglr = zsamples[z_arg] - zsamples.lognl
            zsamples[z_arg] = conversions.snr_from_loglr(loglr)
        zvals.append(zsamples[z_arg])

# close the files, we don't need them anymore
for fp in fps:
    fp.close()

# when the user selected no plot type at all, fall back to defaults that
# depend on how many parameters are being plotted
plot_options = [opts.plot_marginal, opts.plot_scatter, opts.plot_density]
if not numpy.any(plot_options):
    # the 1D marginals are drawn in every default configuration
    # FIXME: with two parameters the plotting code currently needs both
    # plot_scatter and plot_marginal. It should be possible to give only
    # plot_scatter, and that should be the default for two or more
    # parameters.
    opts.plot_marginal = True
    if len(parameters) != 1:
        # anything other than a single parameter also gets scatter plots
        opts.plot_scatter = True

# the prior is only drawn on the 1D marginal plots, so asking for it without
# marginals is rejected before any work is done
if opts.plot_prior is not None and not opts.plot_marginal:
    raise ValueError("prior may only be plotted on 1D marginal plot; either "
                     "turn on --plot-marginal, or turn off --plot-prior")

if opts.plot_prior is not None:
    # build the prior from the given config file(s)
    logging.info("Loading prior")
    prior_config = WorkflowConfigParser(opts.plot_prior)
    prior = option_utils.prior_from_config(prior_config)

    # draw the requested number of samples from it
    logging.info("Drawing samples from prior")
    prior_samples = prior.rvs(opts.prior_nsamples)

    # look up the transforms between the requested parameters and the fields
    # the prior provides, then apply them to add parameters not already
    # present in the prior samples
    prior_fields = prior_samples.fieldnames
    _, ts = transforms.get_common_cbc_transforms(opts.parameters,
                                                 prior_fields)
    prior_samples = transforms.apply_transforms(prior_samples, ts)

# plot ranges given explicitly on the command line, keyed by parameter
mins, maxs = option_utils.plot_ranges_from_cli(opts)

# for every parameter without an explicit bound, use the extreme value found
# across all of the sample sets
for p in parameters:
    if p not in mins:
        lowest_per_file = [s[p].min() for s in samples]
        mins[p] = numpy.array(lowest_per_file).min()
    if p not in maxs:
        highest_per_file = [s[p].max() for s in samples]
        maxs[p] = numpy.array(highest_per_file).max()

# Collect the "expected" value to mark on the plots for each parameter.
# Values taken from the injections are gathered first; values given
# explicitly on the command line are applied afterwards and therefore
# override them.
expected_parameters = {}
if opts.plot_injection_parameters:
    injections = io.injections_from_cli(opts)
    for p in parameters:
        try:
            vals = injections[p]
        except NameError:
            # injection doesn't have this parameter, skip
            # NOTE(review): NameError is the exception this lookup is
            # expected to raise for an unknown parameter -- presumably
            # because names are evaluated as expressions; confirm against
            # the container returned by io.injections_from_cli.
            # logging.warning replaces the deprecated logging.warn alias,
            # which no longer exists in recent Python versions.
            logging.warning(
                "Could not find injection parameter {}".format(p))
            continue
        # check that all of the injections are the same
        unique_vals = numpy.unique(vals)
        if unique_vals.size != 1:
            raise ValueError("More than one injection found! To use "
                "plot-injection-parameters, there must be a single unique "
                "injection in all input files. Use the expected-parameters "
                "option to specify an expected parameter instead.")
        # passed: use the value for the expected
        expected_parameters[p] = unique_vals[0]

# get expected parameter values from command line
expected_parameters.update(option_utils.expected_parameters_from_cli(opts))

# colors to cycle through: "black" first, then "C0" through "C9"
colors = itertools.cycle(["black"] + ["C{}".format(i) for i in range(10)])

# whether exactly one input file was given; this does not change while
# plotting, so it is evaluated once rather than on every iteration
single_input = len(opts.input_file) == 1

# no figure exists yet: the first call below creates it, and every later
# call is handed the figure and axes returned by the previous one
fig = None
axis_dict = None

# plot each input file
logging.info("Plotting")
hist_colors = []
for (i, s) in enumerate(samples):

    # use the user's contour color if one was given, otherwise take the
    # next color from the cycle
    contour_color = opts.contour_color or next(colors)

    if single_input:
        # a lone input gets filled gray histograms with black outlines and
        # navy lines
        fill_color = "gray"
        hist_color = "black"
        linecolor = "navy"
    else:
        # several inputs: leave histograms unfilled and draw everything for
        # this file in its contour color so the files can be told apart
        fill_color = None
        hist_color = contour_color
        linecolor = contour_color

    # remember the histogram color of each file (read again after the loop)
    hist_colors.append(hist_color)

    # plot
    fig, axis_dict = create_multidim_plot(
        parameters, s, labels=labels, fig=fig, axis_dict=axis_dict,
        plot_marginal=opts.plot_marginal,
        marginal_percentiles=opts.marginal_percentiles,
        plot_scatter=opts.plot_scatter,
        zvals=zvals[i] if zvals is not None else None,
        show_colorbar=opts.z_arg is not None,
        cbar_label=zlbl,
        vmin=opts.vmin, vmax=opts.vmax,
        scatter_cmap=opts.scatter_cmap,
        plot_density=opts.plot_density,
        plot_contours=opts.plot_contours,
        contour_percentiles=opts.contour_percentiles,
        density_cmap=opts.density_cmap,
        contour_color=contour_color,
        hist_color=hist_color,
        line_color=linecolor,
        fill_color=fill_color,
        use_kombine=opts.use_kombine_kde,
        mins=mins, maxs=maxs,
        expected_parameters=expected_parameters,
        expected_parameters_color=opts.expected_parameters_color)

# overlay the prior on the existing figure, if one was requested
if opts.plot_prior:
    # with several input files the prior takes the next color in the cycle;
    # otherwise it keeps the histogram color left over from the samples
    if len(opts.input_file) > 1:
        hist_color = next(colors)

    # only the 1D marginals are drawn: no percentile lines, no title, no
    # fill, and a dotted line style
    prior_plot_settings = dict(
        fig=fig,
        axis_dict=axis_dict,
        labels=labels,
        plot_marginal=True,
        marginal_percentiles=[],
        plot_scatter=False,
        plot_density=False,
        plot_contours=False,
        fill_color=None,
        marginal_title=False,
        marginal_linestyle=':',
        hist_color=hist_color,
        mins=mins,
        maxs=maxs)
    fig, axis_dict = create_multidim_plot(parameters, prior_samples,
                                          **prior_plot_settings)

# add a legend in the upper right identifying the input files; it is only
# drawn when there is more than one file to tell apart
if len(opts.input_file) > 1:
    # pair each file's histogram color with its label
    legend_entries = [(color, opts.input_file_labels[fn])
                      for color, fn in zip(hist_colors, opts.input_file)]
    handles = [patches.Patch(color=color, label=label)
               for color, label in legend_entries]
    # kept under its own name so the parameter-label mapping in ``labels``
    # is not overwritten
    legend_labels = [label for _, label in legend_entries]
    fig.legend(loc="upper right", handles=handles,
               labels=legend_labels)

# apply the requested resolution to the figure
fig.set_dpi(opts.dpi)

# write the figure to the output path, passing along the command line that
# produced it together with a title and caption
command_line = " ".join(sys.argv)
metadata.save_fig_with_metadata(
    fig, opts.output_file, {},
    cmd=command_line,
    title="Posteriors",
    caption="Posterior probability density functions.")

# all done
logging.info("Done")
