#! /usr/bin/env python

# Copyright (C) 2016 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import h5py
import logging
import matplotlib as mpl; mpl.use("Agg")
import matplotlib.pyplot as plt
import numpy
import sys
from pycbc import results
from pycbc.filter import autocorrelation
from pycbc.inference import option_utils

# build the command line interface
parser = argparse.ArgumentParser(
    description="Plots autocorrelation function from inference samples.")

# switch that enables logging output
parser.add_argument(
    "--verbose",
    action="store_true",
    default=False,
    help="Print logging info.")

# where the plot is written; this option is mandatory
parser.add_argument(
    "--output-file",
    type=str,
    required=True,
    help="Path to output plot.")

# optional y-axis limits; both are left as None when not given
parser.add_argument(
    "--ymin",
    type=float,
    help="Minimum value to plot on y-axis.")
parser.add_argument(
    "--ymax",
    type=float,
    help="Maximum value to plot on y-axis.")

# let the inference module register its own results options on the parser
option_utils.add_inference_results_option_group(parser)

# read the options given on the command line
opts = parser.parse_args()

# configure logging: DEBUG and above when --verbose is given, otherwise
# only warnings and above
log_level = logging.DEBUG if opts.verbose else logging.WARN
logging.basicConfig(level=log_level, format="%(asctime)s : %(message)s")

# open the results described on the command line without loading samples
fp, parameters, labels, _ = option_utils.results_from_cli(
    opts, load_samples=False)

# build a nested list of autocorrelation functions: the outer list has one
# entry per parameter and each inner list has one ACF per walker
logging.info("Calculating autocorrelation functions")
all_acfs = [
    [
        # read this walker's samples for the parameter, using the thinning
        # options from the command line, and autocorrelate them
        autocorrelation.calculate_acf(
            fp.read_samples(name, walkers=walker,
                            thin_start=opts.thin_start,
                            thin_interval=opts.thin_interval)[name])
        for walker in range(fp.nwalkers)
    ]
    for name in parameters
]

# plot autocorrelation: every ACF of every parameter is drawn as a
# semi-transparent line on the same axes
logging.info("Plotting autocorrelation functions")
fig = plt.figure()
for param_acfs in all_acfs:
    for acf in param_acfs:
        plt.plot(range(len(acf)), acf, alpha=0.25)
plt.xlabel("Iteration")
plt.ylabel(r'Autocorrelation Function for %s'%', '.join(labels))

# format plot; the limits are compared against None rather than tested for
# truthiness so that an explicit limit of 0 (e.g. --ymin 0) is honoured
# instead of being silently ignored
if opts.ymin is not None:
    plt.ylim(ymin=opts.ymin)
if opts.ymax is not None:
    plt.ylim(ymax=opts.ymax)

# title and caption stored alongside the figure as meta-data
title = "Autocorrelation Function for {parameters}".format(
    parameters=", ".join(labels))
caption = """Autocorrelation function (ACF) from all the walker chains for the
parameters. Each line is an ACF for a chain of walker samples."""

# save the figure together with the command line that produced it
results.save_fig_with_metadata(
    fig,
    opts.output_file,
    cmd=" ".join(sys.argv),
    title=title,
    caption=caption)
plt.close()

# close the results file and finish
fp.close()
logging.info("Done")
