#!/usr/bin/env python
""" Bin triggers by their dq value and calculate trigger rates in each bin
"""
import logging
import argparse

import numpy as np
import h5py as h5

from ligo.segments import segmentlist

import pycbc
from pycbc.events import stat as pystat
from pycbc.events.veto import (select_segments_by_definer,
                               start_end_to_segments,
                               segments_to_start_end)
from pycbc.types.optparse import MultiDetOptionAction
from pycbc.io.hdf import HFile, SingleDetTriggers
from pycbc.version import git_verbose_msg as version

# Command-line interface. The module docstring doubles as the program
# description shown by --help.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--version', action='version', version=version)

# Input files and the names used to look things up inside them
parser.add_argument("--template-bins-file", required=True)
parser.add_argument("--trig-file", required=True)
parser.add_argument("--flag-file", required=True)
# --flag-name is later split on ':' into a detector and a flag name
parser.add_argument("--flag-name", required=True)
parser.add_argument("--analysis-segment-file", required=True)
parser.add_argument("--analysis-segment-name", required=True)

# Optional per-detector window around each gate. The adjacent string
# literals below are concatenated, so each piece must carry its own
# trailing space.
parser.add_argument("--gating-windows", nargs='+',
                    action=MultiDetOptionAction,
                    help="Seconds to reweight before and after the central "
                         "time of each gate. Given as detector-values pairs, "
                         "e.g. H1:-1,2.5 L1:-1,2.5 V1:0,0")
parser.add_argument("--stat-threshold", type=float, default=1.,
                    help="Only consider triggers with --sngl-ranking value "
                    "above this threshold")
parser.add_argument("--output-file", required=True)

# Adds the ranking-statistic options; args.sngl_ranking is read further
# down in this script.
pystat.insert_statistic_option_group(
    parser, default_ranking_statistic='single_ranking_only')
args = parser.parse_args()
pycbc.init_logging(args.verbose)

logging.info('Start')

# --flag-name is expected in the form "<ifo>:<flag>"
ifo, flag_name = args.flag_name.split(':')

# Setup a data mask to remove any triggers with SNR below threshold.
# This works as a pre-filter as SNR is always greater than or equal
# to sngl_ranking, except in the psdvar case, where it could increase;
# when psd_var_val is available the SNR is therefore divided by
# sqrt(psd_var_val) before being compared to the threshold.

with HFile(args.trig_file, 'r') as trig_file:
    n_triggers_orig = trig_file[f'{ifo}/snr'].size
    logging.info("Trigger file has %d triggers", n_triggers_orig)
    logging.info('Generating trigger mask')
    if f'{ifo}/psd_var_val' in trig_file:
        idx, _, _ = trig_file.select(
            lambda snr, psdvar: snr / psdvar ** 0.5 >= args.stat_threshold,
            f'{ifo}/snr',
            f'{ifo}/psd_var_val',
            return_indices=True
        )
    else:
        # psd_var_val may not have been calculated
        idx, _ = trig_file.select(
            lambda snr: snr >= args.stat_threshold,
            f'{ifo}/snr',
            return_indices=True
        )
    # Boolean mask over all triggers in the file, True where the
    # pre-filter passed
    data_mask = np.zeros(n_triggers_orig, dtype=bool)
    data_mask[idx] = True

    # also get the gated times while we have the file open
    gate_times = []
    if args.gating_windows:
        logging.info('Getting gated times')
        try:
            gating_types = trig_file[f'{ifo}/gating'].keys()
            for gt in gating_types:
                gate_times += list(trig_file[f'{ifo}/gating/{gt}/time'][:])
        except KeyError:
            logging.warning('No gating found in trigger file')
            # Discard anything partially collected so the "no gating"
            # fallback is consistent with the warning above
            gate_times = []
        # Always hand an ndarray downstream: the gate windows are built
        # with `gate_times + <float>`, which fails on a plain list.
        # np.unique of an empty list yields an empty float array.
        gate_times = np.unique(gate_times)

logging.info("Getting %s triggers from file with pre-cut SNR > %.3f",
             idx.size, args.stat_threshold)

# Load the triggers for this detector, restricted to those that passed
# the pre-filter mask built above.
trigs = SingleDetTriggers(args.trig_file, None, None, None, None, ifo,
                          premask=data_mask)

# Pull out the only per-trigger quantities this script uses
tmplt_ids = trigs.template_id
trig_times = trigs.end_time
stat = trigs.get_ranking(args.sngl_ranking)

# Count before the cut so the number removed can be reported
n_triggers = tmplt_ids.size

logging.info("Applying %s > %.3f cut", args.sngl_ranking,
             args.stat_threshold)

# Keep only triggers whose ranking value reaches the threshold
passes_cut = stat >= args.stat_threshold
tmplt_ids, trig_times = tmplt_ids[passes_cut], trig_times[passes_cut]

n_remaining = tmplt_ids.size
logging.info("Removed %d triggers, %d remain",
             n_triggers - n_remaining, n_remaining)

# Read the template ids of every bin stored under this detector's group,
# keyed by bin name
with h5.File(args.template_bins_file, 'r') as f:
    ifo_grp = f[ifo]
    bin_tids_dict = {
        bin_name: ifo_grp[bin_name]['tids'][:]
        for bin_name in ifo_grp.keys()
    }


# Segments that were analysed for this detector; their total duration is
# the livetime against which each state's duration is compared
analysis_segs = select_segments_by_definer(
    args.analysis_segment_file,
    segment_name=args.analysis_segment_name,
    ifo=ifo)
livetime = abs(analysis_segs)

# Segments during which the data-quality flag is active
flag_segs = select_segments_by_definer(
    args.flag_file, segment_name=flag_name, ifo=ifo)

# Build a window around every gate time; stays empty when no windows
# were requested or when both window edges are zero
gating_segs = segmentlist([])
if args.gating_windows:
    # The per-detector value is "<before>,<after>"
    window_fields = args.gating_windows[ifo].split(',')
    pad_before = float(window_fields[0])
    pad_after = float(window_fields[1])
    if pad_before > 0 or pad_after < 0:
        raise ValueError("Gating window values must be negative "
                         "before gates and positive after gates.")
    if pad_before != 0 or pad_after != 0:
        window_starts = gate_times + pad_before
        window_ends = gate_times + pad_after
        gating_segs = start_end_to_segments(
            window_starts, window_ends).coalesce()

# Restrict everything to analysed time
gating_segs = gating_segs & analysis_segs
flag_segs = flag_segs & analysis_segs

# Mutually exclusive dq states: gated time takes precedence over flagged
# time, and state 0 is whatever analysed time is left over
dq_state_segs_dict = {
    2: gating_segs,
    1: flag_segs - gating_segs,
    0: analysis_segs - flag_segs - gating_segs,
}


# utility function to get the dq state at a given time
def dq_state_at_time(t):
    """Return the dq state whose segments contain time ``t``.

    States are looked up in the module-level ``dq_state_segs_dict`` in
    its iteration order, and the first state whose segments contain
    ``t`` wins. Returns ``None`` when no state contains ``t``.
    """
    matching_states = (
        state
        for state, segs in dq_state_segs_dict.items()
        if t in segs
    )
    return next(matching_states, None)


# compute and save results
# Write the per-bin rates and the dq state segments to the output file
with h5.File(args.output_file, 'w') as f:
    ifo_grp = f.create_group(ifo)
    all_bin_grp = ifo_grp.create_group('bins')
    all_dq_grp = ifo_grp.create_group('dq_segments')

    # One sub-group per template bin
    for bin_name, bin_tids in bin_tids_dict.items():
        bin_grp = all_bin_grp.create_group(bin_name)
        bin_grp['tids'] = bin_tids

        # dq state of every trigger whose template belongs to this bin
        # (None for a trigger that falls in no state)
        in_this_bin = np.isin(tmplt_ids, bin_tids)
        trig_states = np.array(
            list(map(dq_state_at_time, trig_times[in_this_bin])))

        # For each state: fraction of this bin's triggers in the state,
        # divided by the fraction of livetime spent in the state.
        # NOTE(review): the division is unguarded; a state with zero
        # duration or a bin with no triggers looks like it would give a
        # non-finite rate - confirm downstream readers handle that.
        dq_rates = np.zeros(3, dtype=np.float64)
        for state, segs in dq_state_segs_dict.items():
            trigger_fraction = np.mean(trig_states == state)
            time_fraction = abs(segs) / livetime
            dq_rates[state] = trigger_fraction / time_fraction
        bin_grp['dq_rates'] = dq_rates

    # Store the start and end times of each dq state's segments
    for dq_state, segs in dq_state_segs_dict.items():
        dq_grp = all_dq_grp.create_group(f'dq_state_{dq_state}')
        seg_starts, seg_ends = segments_to_start_end(segs)
        dq_grp['segment_starts'] = seg_starts
        dq_grp['segment_ends'] = seg_ends

    f.attrs['stat'] = f'{ifo}-dq_stat_info'

logging.info('Done!')
