#!/usr/bin/python

# Copyright 2019 Gareth S. Davies
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.


import sys, h5py
import argparse, logging

from matplotlib import use
use('Agg')
from matplotlib import pyplot as plt
import numpy as np

from pycbc import events, bin_utils, results
from pycbc.io import HFile, SingleDetTriggers
from pycbc.events import triggers as trigs
from pycbc.events import trigger_fits as trstats
from pycbc.events import stat as pystat
from pycbc.types.optparse import MultiDetOptionAction
import pycbc.version

# Command-line interface: input files, how the template bank is binned and
# split across the plot grid, vetoes, fit configuration and pruning options.
parser = argparse.ArgumentParser(usage="",
    description="Plot histograms of triggers split over various parameters")
pycbc.add_common_pycbc_options(parser)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                         "Required")
parser.add_argument("--bank-file", default=None, required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument('--output-file', required=True,
                    help="Output image file. Required")
parser.add_argument('--bin-param', default='template_duration',
                    help="Parameter for binning within plots, default "
                         "'template_duration'")
parser.add_argument('--bin-spacing', choices=['linear', 'log'], default='log',
                    help="How to space bin-param bin edges. "
                         "Choices=[linear, log], default log")
parser.add_argument('--num-bins', type=int, default=6,
                    help="Number of bins over which to split bin-param, default 6")
parser.add_argument('--max-bin-param', type=float, default=None,
                    help="Maximum allowed value of bin-param")
parser.add_argument('--split-param-one', default='eta',
                    help="Parameter for splitting plot grid in y-direction, "
                         "default 'eta'")
parser.add_argument('--split-param-two', default='chi_eff',
                    help="Parameter for splitting plot grid in x-direction, "
                         "default 'chi_eff'")
parser.add_argument('--split-one-nbins', default=3, type=int,
                    help="Number of split plots over split-param-one, default 3")
parser.add_argument('--split-two-nbins', default=4, type=int,
                    help="Number of split plots over split-param-two, default 4")
parser.add_argument('--split-one-spacing', choices=['linear', 'log'], default='linear',
                    help="How to space split-param-one bin edges. "
                         "Choices=[linear, log], default=linear")
parser.add_argument('--split-two-spacing', choices=['linear', 'log'], default='linear',
                    help="How to space split-param-two bin edges. "
                         "Choices=[linear, log], default=linear")
# The detector name is used to build every dataset path read from the trigger
# file, so enforce it here instead of failing later on a None concatenation.
parser.add_argument('--ifo', required=True, help="Detector. Required")
parser.add_argument('--plot-max-x', type=float, default=None,
                    help="Maximum stat value to plot, if not given "
                         "1.05 * largest stat value will be used")
parser.add_argument("--veto-file",
                    help="File(s) in .xml format with veto segments to apply "
                         "to triggers before fitting")
parser.add_argument("--veto-segment-name",
                    help="Name(s) of veto segments to apply. Optional, if not "
                         "given all triggers for the given ifo will be used")
parser.add_argument("--gating-veto-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to be vetoed before and after the central time "
                         "of each gate. Given as detector-values pairs, e.g. "
                         "H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--stat-fit-threshold", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                         "threshold. Required")
parser.add_argument("--plot-lower-stat-limit", type=float, required=True,
                    help="Plot triggers down to this value. Setting this too "
                         "low will incur huge memory usage in a full search. "
                         "To avoid this, choose 5.5 or larger.")
# The fit function name is used both for the fit itself and in the legend and
# caption text, so a missing value cannot be handled further down.
parser.add_argument("--fit-function", required=True,
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit. "
                         "Required")
parser.add_argument("--prune-number", type=int, default=0,
                    help="Number of loudest events to remove from each split "
                         "histogram, default 0")
parser.add_argument("--prune-window", type=float, default=0.1,
                    help="Time (s) to remove all triggers around a trigger "
                         "which is loudest in each split, default 0.1s")

pystat.insert_statistic_option_group(parser,
    default_ranking_statistic='single_ranking_only')
args = parser.parse_args()

# Triggers are only loaded down to the plotting limit, so the fit threshold
# must not lie below it. Reported through the parser (rather than an assert)
# so the check survives `python -O` and gives a usable message.
if not args.stat_fit_threshold >= args.plot_lower_stat_limit:
    parser.error("--stat-fit-threshold must be greater than or equal to "
                 "--plot-lower-stat-limit")

pycbc.init_logging(args.verbose)

logging.info('Opening trigger file: %s', args.trigger_file)
# Kept open: further datasets are read from this handle further down.
trigf = h5py.File(args.trigger_file, 'r')

logging.info('Opening template file: %s', args.bank_file)
with h5py.File(args.bank_file, 'r') as bank:
    logging.info('setting up template bank parameters')
    # Pull the basic template parameters fully into memory
    params = {
        name: bank[name][:]
        for name in ('mass1', 'mass2', 'spin1z', 'spin2z')
    }

# Derived parameters are only worked out when one of the binning or
# splitting options actually asks for them.
usedparams = [args.split_param_two, args.split_param_one, args.bin_param]
extparams = [
    par
    for par in ('mchirp', 'eta', 'chi_eff', 'template_duration')
    if par in usedparams
]

for ex_p in extparams:
    if ex_p != 'template_duration':
        logging.info("Calculating %s from template parameters", ex_p)
        params[ex_p] = trigs.get_param(
            ex_p,
            args,
            params['mass1'],
            params['mass2'],
            params['spin1z'],
            params['spin2z'],
        )
        continue

    logging.info('Reading duration from trigger file')
    duration_dset = trigf[args.ifo + '/template_duration']
    region_refs = trigf[args.ifo + '/template_duration_template'][:]
    # One duration per template, taken from the first trigger in the
    # template's region reference. For a template without triggers the
    # zeroth entry reads back as zero because of an h5py quirk.
    params[ex_p] = np.array([duration_dset[ref][0] for ref in region_refs])

# Format strings used when printing parameter values in labels and logs
formats = {
    "mchirp": '{:.2f}',
    "eta": '{:.3f}',
    "chi_eff": '{:.3f}',
    'template_duration': '{:.0f}',
}

logging.info('setting up {}'.format(usedparams))
# (min, max, n) inputs for the bin objects along each split direction
sp_one_bin_input = (params[args.split_param_one].min(),
                    params[args.split_param_one].max(), args.split_one_nbins)
sp_two_bin_input = (params[args.split_param_two].min(),
                    params[args.split_param_two].max(), args.split_two_nbins)
if args.max_bin_param:
    # Parameters without a dedicated entry in `formats` (e.g. mass1) fall
    # back to a generic float format instead of raising KeyError.
    bin_param_format = formats.get(args.bin_param, '{:.3f}')
    logging.info(('setting maximum {} value: '
                  + bin_param_format).format(args.bin_param,
                                             args.max_bin_param))
    pbin_upper_lim = float(args.max_bin_param)
else:
    pbin_upper_lim = params[args.bin_param].max()

# For templates with no triggers, the duration will be read as zero
if args.bin_param == 'template_duration' and params[args.bin_param].min() == 0:
    # logging.warn is a deprecated alias that was removed in Python 3.13
    logging.warning('WARNING: Some templates do not contain triggers')
    # Use the lowest nonzero template duration as lower limit for bins
    pbin_lower_lim = params[args.bin_param][params[args.bin_param] > 0].min()
else:
    pbin_lower_lim = params[args.bin_param].min()

bb_input = (pbin_lower_lim, pbin_upper_lim, args.num_bins)

logging.info('splitting {} into bins'.format(args.bin_param))
if args.bin_spacing == 'log':
    # Explicit exceptions rather than asserts, so the checks are not
    # stripped when running under `python -O`.
    if not pbin_lower_lim > 0:
        raise ValueError("Logarithmic bin spacing requires positive {} "
                         "values".format(args.bin_param))
    pbins = bin_utils.LogarithmicBins(*bb_input)
else:
    pbins = bin_utils.LinearBins(*bb_input)
# Use sentinel value -1 for templates outside range. The comparison is strict
# at both ends, so templates exactly at a limit also get the sentinel.
pind = np.array([pbins[par] if pbin_lower_lim < par < pbin_upper_lim
                 else -1 for par in params[args.bin_param]])

if args.split_one_spacing == 'log':
    if not params[args.split_param_one].min() > 0:
        raise ValueError("Logarithmic spacing requires positive {} "
                         "values".format(args.split_param_one))
    sp_one_bounds = bin_utils.LogarithmicBins(*sp_one_bin_input)
else:
    sp_one_bounds = bin_utils.LinearBins(*sp_one_bin_input)

if args.split_two_spacing == 'log':
    if not params[args.split_param_two].min() > 0:
        raise ValueError("Logarithmic spacing requires positive {} "
                         "values".format(args.split_param_two))
    sp_two_bounds = bin_utils.LogarithmicBins(*sp_two_bin_input)
else:
    sp_two_bounds = bin_utils.LinearBins(*sp_two_bin_input)

logging.info('assigning template ids to different splits')
# For each split bin, collect the (sorted) indices of the templates whose
# parameter value lies in the half-open interval (lower, upper].
split_one_values = params[args.split_param_one]
id_in_bin1 = [
    np.flatnonzero((split_one_values > lower_1) & (split_one_values <= upper_1))
    for lower_1, upper_1 in zip(sp_one_bounds.lower(), sp_one_bounds.upper())
]

split_two_values = params[args.split_param_two]
id_in_bin2 = [
    np.flatnonzero((split_two_values > lower_2) & (split_two_values <= upper_2))
    for lower_2, upper_2 in zip(sp_two_bounds.lower(), sp_two_bounds.upper())
]

logging.info('Getting template boundaries from trigger file')
boundaries = trigf[args.ifo + '/template_boundaries'][:]

# Everything needed from this handle has now been read into memory
trigf.close()

# Build a boolean mask that drops triggers whose SNR is below the plotting
# limit before anything else is loaded. This is usable as a pre-filter since
# SNR is always greater than or equal to sngl_ranking, except in the psdvar
# case, where it could increase.
with HFile(args.trigger_file, 'r') as trig_file:
    snr_path = f'{args.ifo}/snr'
    psdvar_path = f'{args.ifo}/psd_var_val'

    n_triggers_orig = trig_file[snr_path].size
    logging.info("Trigger file has %d triggers", n_triggers_orig)
    logging.info('Generating trigger mask')

    if psdvar_path not in trig_file:
        # psd_var_val may not have been calculated: cut on plain SNR
        def _passes_limit(snr):
            return snr >= args.plot_lower_stat_limit

        idx, _ = trig_file.select(
            _passes_limit,
            snr_path,
            return_indices=True
        )
    else:
        # Cut on SNR rescaled by the PSD variation value
        def _passes_limit(snr, psdvar):
            return snr / psdvar ** 0.5 >= args.plot_lower_stat_limit

        idx, _, _ = trig_file.select(
            _passes_limit,
            snr_path,
            psdvar_path,
            return_indices=True
        )

    data_mask = np.zeros(n_triggers_orig, dtype=bool)
    data_mask[idx] = True


logging.info('Calculating single stat values from trigger file')
# Only the trigger file, detector and pre-mask are supplied; the remaining
# positional arguments are deliberately passed as None.
trigs = SingleDetTriggers(
    args.trigger_file, None, None, None, None, args.ifo,
    premask=data_mask
)
# Direct handle on the underlying HDF file, needed again further down
trigf = trigs.trigs_f

stat = trigs.get_ranking(args.sngl_ranking)
time = trigs.end_time


logging.info('Processing template boundaries')
# Position (in template order) of the template whose block starts last in the
# file, and the block start positions sorted into file order.
max_boundary_id = np.argmax(boundaries)
sorted_boundary_list = np.sort(boundaries)

# In the two blocks of code that follow we are trying to figure out the index
# ranges in the masked trigger lists corresponding to the "boundaries".
# We will do this by walking over the boundaries in the order they're
# stored in the file, adding in the number of triggers not removed by the
# mask every time.

# First step is to loop over the "boundaries" which gives the start position
# of each block of triggers (corresponding to one template) in the full trigger
# merge file.
# Here we identify the end_idx for the triggers corresponding to each template:
# the start of the next block in file order, or the total number of triggers
# for the block that starts last.
# NOTE(review): np.argmax on the equality test returns the first match, so
# this presumably relies on boundary values being unique - confirm.
where_idx_end = np.zeros_like(boundaries)
for idx, idx_start in enumerate(boundaries):
    if idx == max_boundary_id:
        where_idx_end[idx] = trigf[args.ifo + '/end_time'].size
    else:
        where_idx_end[idx] = sorted_boundary_list[
            np.argmax(sorted_boundary_list == idx_start) + 1]

# Next we need to map these start/stop indices in the full file, to the start
# stop indices in the masked list of triggers. We do this by figuring out
# how many triggers are in the masked list for each template in the order they
# are stored in the trigger merge file, and keep a running sum.
curr_count = 0
mask_start_idx = np.zeros_like(boundaries)
mask_end_idx = np.zeros_like(boundaries)
for idx_start in sorted_boundary_list:
    # Template (in template order) whose block starts at this file position
    boundary_idx = np.argmax(boundaries == idx_start)
    idx_end = where_idx_end[boundary_idx]
    mask_start_idx[boundary_idx] = curr_count
    # Number of this template's triggers that survive the mask
    curr_count += np.sum(trigs.mask[idx_start:idx_end])
    mask_end_idx[boundary_idx] = curr_count


if args.veto_file:
    logging.info('Applying DQ vetoes')
    vetoed_idx, _ = events.veto.indices_within_segments(
        time,
        [args.veto_file],
        ifo=args.ifo,
        segment_name=args.veto_segment_name
    )
    # Vetoed triggers are zeroed rather than dropped: with a fit threshold
    # above zero they are neither fitted nor plotted, and keeping the array
    # length fixed means the per-template index ranges stay valid.
    stat[vetoed_idx] = 0.
    time[vetoed_idx] = 0.
    logging.info(
        '%d out of %d trigs removed after vetoing with %s from %s',
        vetoed_idx.size,
        stat.size,
        args.veto_segment_name,
        args.veto_file
    )

if args.gating_veto_windows:
    logging.info('Applying veto to triggers near gates')
    window_fields = args.gating_veto_windows[args.ifo].split(',')
    gveto_before = float(window_fields[0])
    gveto_after = float(window_fields[1])
    if gveto_before > 0 or gveto_after < 0:
        raise ValueError("Gating veto window values must be negative before "
                         "gates and positive after gates.")
    # A window of exactly 0,0 means no triggers are vetoed around gates
    if gveto_before != 0 or gveto_after != 0:
        autogate_times = np.unique(trigf[args.ifo + '/gating/auto/time'][:])
        # Gates read from the 'gating/file' group are optional
        file_gate_group = args.ifo + '/gating/file'
        detgate_times = (trigf[file_gate_group + '/time'][:]
                         if file_gate_group in trigf else [])
        gate_times = np.concatenate((autogate_times, detgate_times))
        near_gate_idx = events.veto.indices_within_times(
            time,
            gate_times + gveto_before,
            gate_times + gveto_after
        )
        # Same zeroing convention as for the DQ vetoes
        stat[near_gate_idx] = 0.
        time[near_gate_idx] = 0.
        logging.info(
            '%d out of %d trigs removed after vetoing triggers near gates',
            near_gate_idx.size,
            stat.size
        )

# Pruning: in every (split-param-one, split-param-two) cell, repeatedly find
# the loudest remaining trigger and zero out all triggers (in any template)
# within --prune-window seconds of it, --prune-number times per cell.
if not args.prune_number:
    logging.info('Not performing any pruning')
else:
    logging.info('Applying pruning around loudest triggers in each split')
    for x in range(args.split_one_nbins):
        id_bin1 = id_in_bin1[x]
        for y in range(args.split_two_nbins):
            id_bin2 = id_in_bin2[y]
            # Finding ids of templates in split
            id_in_both = np.intersect1d(id_bin1, id_bin2)
            if len(id_in_both) == 0:
                continue

            # Gather (copies of) the triggers belonging to these templates
            vals_inbin = np.concatenate(
                [stat[mask_start_idx[idx]:mask_end_idx[idx]]
                 for idx in id_in_both])
            time_inbin = np.concatenate(
                [time[mask_start_idx[idx]:mask_end_idx[idx]]
                 for idx in id_in_both])

            logging.info('Pruning in split {}-{} {}-{}'.format(
                         args.split_param_one, x, args.split_param_two, y))
            logging.info('Currently have %d triggers', len(vals_inbin))
            if vals_inbin.size == 0:
                # argmax of an empty array would raise; nothing to prune here
                logging.info('No triggers in this split, nothing to prune')
                continue

            for count_pruned in range(args.prune_number):
                # Getting loudest statistic value in split
                max_val_arg = vals_inbin.argmax()
                if vals_inbin[max_val_arg] <= 0:
                    # Everything left has already been vetoed or pruned
                    logging.info('No triggers left to prune in this split')
                    break
                # max_val_arg indexes the per-split arrays, not the global
                # ones, so the pruning time must be read from time_inbin.
                prune_time = time_inbin[max_val_arg]

                remove = np.nonzero(
                    abs(prune_time - time) < args.prune_window
                )[0]
                # Remove from inbin triggers as well in case there
                # are more pruning iterations
                remove_inbin = np.nonzero(
                    abs(prune_time - time_inbin) < args.prune_window
                )[0]
                logging.info(
                    'Prune %d: removing %d triggers around %.2f, '
                    '%d in this split',
                    count_pruned,
                    remove.size,
                    prune_time,
                    remove_inbin.size
                )
                # Set pruned triggers' stat values to zero, matching the
                # convention used for vetoed triggers
                vals_inbin[remove_inbin] = 0.
                time_inbin[remove_inbin] = 0.
                stat[remove] = 0.
                time[remove] = 0.

logging.info('Setting up plotting and fitting limit values')
# Zero-valued stats mark vetoed/pruned triggers, so ignore them for the limits
minplot = max(stat[np.nonzero(stat)].min(), args.plot_lower_stat_limit)
min_fit = max(minplot, args.stat_fit_threshold)
max_fit = 1.05 * stat.max()
if args.plot_max_x:
    maxplot = args.plot_max_x
else:
    maxplot = max_fit
fitrange = np.linspace(min_fit, max_fit, 100)

logging.info('Setting up plotting variables')
histcolors = ['r',(1.0,0.6,0),'y','g','c','b','m','k',(0.8,0.25,0),(0.25,0.8,0)]
# One panel per (split-param-one, split-param-two) cell; the extra column of
# width is reserved for the figure-level legend on the left.
fig, axes = plt.subplots(args.split_one_nbins, args.split_two_nbins,
                         sharex=True, sharey=True, squeeze=False,
                         figsize=(3 * (args.split_two_nbins + 1),
                                  3 * args.split_one_nbins))

# Setting up overall legend outside the split-up plots, using degenerate
# proxy lines that only exist to provide legend handles
lines = []
labels = []
for i, (lower, upper) in enumerate(zip(pbins.lower(), pbins.upper())):
    binlabel = r"%.3g - %.3g" % (lower, upper)
    # Colours wrap around so more bins than colours does not raise IndexError
    line, = axes[0,0].plot([0,0], [0,0], linewidth=2,
                           color=histcolors[i % len(histcolors)], alpha=0.6)
    lines.append(line)
    labels.append(binlabel)

line_fit, = axes[0,0].plot([0,0], [0,0], linestyle='--',
                           color='k', alpha=0.6)
lines.append(line_fit)
labels.append(args.fit_function + ' fit to counts')
fig.legend(lines, labels, labelspacing=0.2,
           loc='upper left', title=args.bin_param)

# Template indices falling in each bin-param bin (sentinel -1 never matches)
pidx = [np.flatnonzero(pind == i) for i in range(args.num_bins)]

logging.info('Starting bin, histogram and plot loop')
# Largest cumulative count over every histogram drawn, for the y-axis limit
maxyval = 0
for x in range(args.split_one_nbins):
    id_bin1 = id_in_bin1[x]
    for y in range(args.split_two_nbins):
        id_bin2 = id_in_bin2[y]
        id_in_both = np.intersect1d(id_bin1, id_bin2)
        logging.info('split {}: {}, {}: {}'.format(args.split_param_one, x,
                                                   args.split_param_two, y))
        ax = axes[x,y]
        for i, (lower, upper) in enumerate(zip(pbins.lower(), pbins.upper())):
            indices_all_conditions = np.intersect1d(pidx[i], id_in_both)
            logging.info('{} split {}-{}'.format(args.bin_param, lower, upper))
            if len(indices_all_conditions) == 0:
                continue
            # All (masked) triggers from the templates in this bin and split
            vals_inbin = np.concatenate(
                [stat[mask_start_idx[idx]:mask_end_idx[idx]]
                 for idx in indices_all_conditions])
            vals_above_thresh = vals_inbin[vals_inbin >= args.stat_fit_threshold]
            if not len(vals_above_thresh):
                logging.info('No triggers above threshold')
                continue
            logging.info('{} triggers out of {} above threshold'.format(
                         len(vals_above_thresh), len(vals_inbin)))
            # sig_alpha (the fit uncertainty) is returned but the
            # corresponding 1-sigma bounds are not currently plotted
            alpha, sig_alpha = trstats.fit_above_thresh(args.fit_function,
                                 vals_above_thresh, args.stat_fit_threshold)
            fitted_cum_counts = len(vals_above_thresh) * \
                                trstats.cum_fit(args.fit_function, fitrange,
                                                alpha, args.stat_fit_threshold)

            # Cumulative (counting from the loud end) histogram of all
            # trigger values in this bin, including those below threshold
            histcounts, edges = np.histogram(vals_inbin, bins=50)
            cum_counts = histcounts[::-1].cumsum()[::-1]
            # Track the maximum here, per histogram: doing it after the bin
            # loop would use an undefined or stale cum_counts whenever a
            # split has nothing to plot, and would miss all but the last bin.
            maxyval = max(maxyval, cum_counts.max())

            # Plot the lines!
            color = histcolors[i % len(histcolors)]
            ax.semilogy(edges[:-1], cum_counts, linewidth=2,
                        color=color, alpha=0.6)
            ax.semilogy(fitrange, fitted_cum_counts, "--", color=color,
                        label=r"$\alpha = $%.2f" % alpha)
            ax.legend(fontsize='small', framealpha=0.5)

            del vals_inbin, vals_above_thresh

        ax.grid()

# Keep the upper y limit above the lower one even if nothing was plotted
ymax = 5 * max(maxyval, 1)
# Mark the fit threshold on every panel
for ax in axes.flat:
    ax.semilogy([args.stat_fit_threshold, args.stat_fit_threshold],
                [1, ymax], 'k', linestyle=':', alpha=0.2)
# Axes are shared, so setting limits on one panel applies to all
axes[0,0].set_ylim(1, ymax)
axes[0,0].set_xlim(minplot, maxplot)

# Label formats for the split parameters. Parameters without a dedicated
# entry in `formats` (e.g. mass1) fall back to a generic float format rather
# than raising KeyError at the very end of the run.
split_two_format = formats.get(args.split_param_two, '{:.3f}')
split_one_format = formats.get(args.split_param_one, '{:.3f}')

# Bottom row: ranking statistic on the x axis. Top row: range of
# split-param-two covered by each column (skipped for a single column).
# NOTE(review): with a single row and several columns both labels target
# axes[0, j], so the second call replaces the ranking label - confirm whether
# that is intended.
split_two_lower = sp_two_bounds.lower()
split_two_upper = sp_two_bounds.upper()
for j in range(args.split_two_nbins):
    axes[args.split_one_nbins - 1, j].set_xlabel(args.sngl_ranking, size="large")
    if args.split_two_nbins == 1:
        break
    column_range = (split_two_format + ' to ' + split_two_format).format(
        split_two_lower[j], split_two_upper[j])
    axes[0, j].set_xlabel(args.split_param_two + ': ' + column_range,
                          size="large")
    axes[0, j].xaxis.set_label_position("top")

# Left column: range of split-param-one covered by each row, or a plain
# y-axis label when there is only one row.
split_one_lower = sp_one_bounds.lower()
split_one_upper = sp_one_bounds.upper()
for i in range(args.split_one_nbins):
    if args.split_one_nbins == 1:
        axes[0, 0].set_ylabel('Cumulative number', size='large')
        break
    row_range = (split_one_format + ' to ' + split_one_format).format(
        split_one_lower[i], split_one_upper[i])
    axes[i, 0].set_ylabel(args.split_param_one + ': ' + row_range
                          + '\ncumulative number', size="large")

# Leave the leftmost 1/(ncols + 1) of the figure free for the legend
fig.tight_layout(rect=(1./(args.split_two_nbins+1), 0, 1, 1))

logging.info('Saving to file ' + args.output_file)
results.save_fig_with_metadata(
    fig, args.output_file,
    title="{}: {} histogram of single detector triggers split by"
          " {} and {}".format(args.ifo, args.sngl_ranking, args.split_param_one,
                              args.split_param_two),
    caption=(r"Histogram of {} single detector {} values binned by {}, split by "
             "{} and {}, with fitted {} distribution parameterized by"
             " &alpha;".format(args.ifo, args.sngl_ranking, args.bin_param,
                               args.split_param_one, args.split_param_two,
                               args.fit_function)),
    cmd=" ".join(sys.argv)
)
logging.info('Done!')
