#!/home/travis/build/pycbc-build/environment/bin/python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import corner
import itertools
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.pyplot as plt
import numpy
import sys
from pycbc import distributions
from pycbc import inference
from pycbc import results
from pycbc.workflow import WorkflowConfigParser

def cartesian(arrays):
    """ Returns the cartesian product of a list of iterables as an array.

    Parameters
    ----------
    arrays : iterable of iterables
        One iterable of candidate values per axis.

    Returns
    -------
    numpy.ndarray
        Array built from every combination of one value per input iterable,
        in ``itertools.product`` order (the last iterable varies fastest).
        Column ``i`` of a row holds the value taken from ``arrays[i]``.
    """
    combinations = itertools.product(*arrays)
    # convert each combination to its own array before stacking them
    rows = list(map(numpy.array, combinations))
    return numpy.array(rows)

# command line usage
parser = argparse.ArgumentParser(
    usage="pycbc_inference_plot_prior [--options]",
    description="Plots prior distributions.")

# add input options
# --sections accepts either "section" (use all of its subsections) or
# "section-subsection" (use only that subsection)
parser.add_argument("--config-files", type=str, nargs="+", required=True,
    help="A file parsable by pycbc.workflow.WorkflowConfigParser.")
parser.add_argument("--sections", type=str, nargs="+", default=["prior"],
    help="Name of section plus subsection with distribution configurations, eg. prior-mass1.")

# add prior options
# --bins is the number of grid points per parameter, so the number of points
# where the PDF is evaluated grows as bins**(number of parameters)
parser.add_argument("--bins", type=int, default=20,
    help="Number of points to grid a parameter. Need at least 20 bins to fill in the plot.")

# add output options
parser.add_argument("--output-file", type=str, required=True,
    help="Path to output plot.")

# verbose option
# the help text was previously empty, leaving the flag undocumented in --help
parser.add_argument("--verbose", action="store_true", default=False,
    help="Print logging messages at DEBUG level and above.")

# parse the command line
opts = parser.parse_args()

# setup log
# --verbose lowers the threshold to DEBUG; otherwise only warnings and above
# are emitted
log_level = logging.DEBUG if opts.verbose else logging.WARN
logging.basicConfig(level=log_level, format="%(asctime)s : %(message)s")

# read configuration file
logging.info("Reading configuration files")
cp = WorkflowConfigParser(opts.config_files)

# build one distribution for every requested prior (sub)section
# each --sections value is split at its first "-": the text before it is the
# section and the text after it is the subsection
# when there is no "-", every subsection of that section is used
logging.info("Constructing prior")
variable_args = []
dists = []
for sec in opts.sections:
    section, dash, remainder = sec.partition("-")
    # a dash means a subsection was given explicitly (even an empty one)
    subsections = [remainder] if dash else cp.get_subsections(section)
    for subsection in subsections:
        # the "name" option selects which distribution class to build
        name = cp.get_opt_tag(section, "name", subsection)
        dist_class = distributions.distribs[name]
        dist = dist_class.from_config(cp, section, subsection)
        variable_args.extend(dist.params)
        dists.append(dist)

# parameters are kept in sorted order; grid rows and plot labels index into it
variable_args.sort()
ndim = len(variable_args)

# construct the prior evaluator from all of the distributions
# (the resulting object is not referenced again in this script)
prior = inference.PriorEvaluator(variable_args, *dists)

# get all points in space to calculate PDF
# each parameter is sampled at opts.bins evenly spaced points that start at
# its lower bound and stop one step short of its upper bound; row idx of vals
# holds the samples for variable_args[idx]
logging.info("Getting grid of points")
vals = numpy.zeros(shape=(ndim, opts.bins))
# one independent placeholder per parameter ([{}] * ndim would alias a single
# dict across every slot)
bounds = [{} for _ in range(ndim)]
for dist in dists:
    for param in dist.params:
        idx = variable_args.index(param)
        lower = float(dist.bounds[param][0])
        upper = float(dist.bounds[param][1])
        # linspace always returns exactly opts.bins points, whereas arange
        # with a floating-point step can return one more or one fewer because
        # of round-off, which would not fit into the row of vals
        vals[idx, :] = numpy.linspace(lower, upper, opts.bins, endpoint=False)
        bounds[idx] = dist.bounds

# every combination of one sample per parameter, one grid point per row
pts = cartesian(vals)

# evaluate PDF between the bounds
# every grid point is turned into a {parameter: value} dict and handed in full
# to each distribution's pdf()
logging.info("Calculating PDF")
pdf = []
for pt in pts:
    pt_dict = dict(zip(variable_args, pt))
    # the distributions are treated as independent components of the prior,
    # so the joint density at a point is the product of their densities
    # (a sum is only correct for log-densities)
    joint = 1.0
    for dist in dists:
        joint *= dist.pdf(**pt_dict)
    pdf.append(joint)
pdf = numpy.array(pdf)

# choose the plot type from the number of parameters
logging.info("Plotting")
if ndim != 1:
    # several parameters: corner plot of the grid weighted by the PDF
    fig = corner.corner(pts, weights=pdf, labels=variable_args,
                        plot_contours=False, plot_datapoints=False)

    # remove the 1-D histograms, which sit on the diagonal of the
    # ndim x ndim grid of axes
    for diagonal_ax in fig.axes[::ndim + 1]:
        fig.delaxes(diagonal_ax)

    # adjust size of remaining plots to cover entire canvas
    stretch = 1 + 1.0 / ndim
    fig.subplots_adjust(top=stretch)
    fig.subplots_adjust(right=stretch)

    # shift the axis labels of every remaining panel
    for panel in fig.axes:
        panel.yaxis.get_label().set_position((-0.2, 0.5))
        panel.xaxis.get_label().set_position((0.5, -0.2))

else:
    # exactly one parameter: plot the PDF as a line over the sampled values
    samples = vals[0, :]
    lowest, highest = samples.min(), samples.max()
    # pad the x-axis by 10% of the sampled range on both sides
    margin = 0.1 * (highest - lowest)
    fig = plt.figure()
    plt.plot(samples, pdf, "k", label="Prior")
    plt.xlim(lowest - margin, highest + margin)
    plt.xlabel(variable_args[0])
    plt.ylabel("Probability Density Function")
    plt.legend()

# save figure with meta-data
# the title lists the plotted parameters; the command line is stored as well
parameters = ", ".join(variable_args)
title = "Prior Distributions for {parameters}".format(parameters=parameters)
caption = ("This plot shows the probability density function (PDF) from the \n"
           "prior distributions.")
results.save_fig_with_metadata(fig, opts.output_file,
                               cmd=" ".join(sys.argv),
                               title=title,
                               caption=caption)
plt.close()

# exit
logging.info("Done")
