#!/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Followup foreground events
"""
import os
import sys
import argparse
import logging
import re
import h5py

import pycbc.workflow as wf
from pycbc.results import layout
from pycbc.types import MultiDetOptionAction
from pycbc.events import select_segments_by_definer, coinc
from pycbc.io import get_all_subkeys
import pycbc.workflow.minifollowups as mini
from pycbc.workflow.core import resolve_url_to_file
import pycbc.version

# Command-line interface. The module docstring, minus its first character,
# is used as the description text.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument(
    '--version',
    action='version',
    version=pycbc.version.git_verbose_msg,
)

# --- Input files -----------------------------------------------------------
parser.add_argument(
    '--bank-file',
    help="HDF format template bank file",
)
parser.add_argument(
    '--statmap-file',
    help="HDF format clustered coincident trigger result file",
)
parser.add_argument(
    '--single-detector-triggers',
    nargs='+',
    action=MultiDetOptionAction,
    help="HDF format merged single detector trigger files",
)

# --- Inspiral segment information ------------------------------------------
parser.add_argument(
    '--inspiral-segments',
    help="xml segment files containing the inspiral analysis times",
)
parser.add_argument(
    '--inspiral-data-read-name',
    help="Name of inspiral segmentlist containing data read in "
         "by each analysis job.",
)
parser.add_argument(
    '--inspiral-data-analyzed-name',
    help="Name of inspiral segmentlist containing data "
         "analyzed by each analysis job.",
)

# --- Event selection and ordering ------------------------------------------
parser.add_argument(
    '--analysis-category',
    type=str,
    required=False,
    default='background_exc',
    choices=['foreground', 'background', 'background_exc'],
    help='Designates whether to look at foreground triggers '
         'background triggers (including "little dogs") '
         'or background triggers (with "little dogs" removed)',
)
parser.add_argument(
    '--sort-variable',
    default='ifar',
    help='Which subgroup of --analysis-category to use for '
         'sorting. Default=ifar',
)
parser.add_argument(
    '--sort-order',
    default='descending',
    choices=['ascending', 'descending'],
    help='Which direction to use when sorting on '
         '--sort-variable. Default=descending',
)

# Let pycbc.workflow register its own option groups on the parser before
# the command line is parsed.
wf.add_workflow_command_line_group(parser)
wf.add_workflow_settings_cli(parser, include_subdax_opts=True)
args = parser.parse_args()

# Timestamped, INFO-level logging for the rest of the script.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s:%(levelname)s : %(message)s',
)

workflow = wf.Workflow(args)
wf.makedir(args.output_dir)

# Plain list of layout entries; it is filled while events are processed and
# finally handed to layout.two_column_layout.
layouts = []

# Resolve the bank, statmap and segment paths (in that order) to workflow
# file objects.
tmpltbank_file, coinc_file, insp_segs = (
    resolve_url_to_file(os.path.abspath(path))
    for path in (args.bank_file, args.statmap_file, args.inspiral_segments)
)

# Per-detector inputs, keyed by ifo except for the ordered single_triggers
# list:
#   fsdt                    open HDF handle on the single-detector triggers
#   insp_data_seglists      segments named by --inspiral-data-read-name
#   insp_analysed_seglists  segments named by --inspiral-data-analyzed-name
single_triggers = []
fsdt = {}
insp_data_seglists = {}
insp_analysed_seglists = {}
for ifo in args.single_detector_triggers:
    trig_path = args.single_detector_triggers[ifo]
    single_triggers.append(
        resolve_url_to_file(os.path.abspath(trig_path), attrs={'ifos': ifo})
    )
    fsdt[ifo] = h5py.File(trig_path, 'r')

    data_segs = select_segments_by_definer(
        args.inspiral_segments,
        segment_name=args.inspiral_data_read_name,
        ifo=ifo,
    )
    analysed_segs = select_segments_by_definer(
        args.inspiral_segments,
        segment_name=args.inspiral_data_analyzed_name,
        ifo=ifo,
    )
    # NOTE: make_singles_timefreq needs a coalesced set of segments. If this is
    #       being used to determine command-line options for other codes,
    #       please think if that code requires coalesced, or not, segments.
    data_segs.coalesce()
    analysed_segs.coalesce()
    insp_data_seglists[ifo] = data_segs
    insp_analysed_seglists[ifo] = analysed_segs

# Number of events to follow up, taken from the 'num-events' option of the
# 'workflow-minifollowups' section of the workflow configuration.
num_events = int(workflow.cp.get_opt_tags('workflow-minifollowups',
                                          'num-events', ''))
file_val = args.analysis_category

# Everything needed from the statmap file is read inside the context manager
# so the file is closed even if the sort-variable check raises.
with h5py.File(args.statmap_file, 'r') as f:
    f_cat = f[file_val]
    stat = f_cat['stat'][:]

    if args.sort_variable not in f_cat:
        # Offer the datasets available under the chosen category, with the
        # category prefix removed, as the valid choices.
        all_datasets = [re.sub(file_val, '', ds).strip('/')
                        for ds in get_all_subkeys(f, file_val)]
        raise KeyError(
            f'Sort variable {args.sort_variable} not in {file_val}: sort '
            f'choices in {file_val} are ' + ', '.join(all_datasets)
        )

    # In case we are doing background minifollowup with repeated events,
    # we must include the ordering / template / trigger / time info for
    # _many_ more events to make sure we get enough
    events_to_read = num_events * 100

    # Never ask for more events than the file contains
    num_events = min(num_events, len(stat))
    events_to_read = min(events_to_read, len(stat))

    # Get the indices of the events we are considering in the order specified
    sorting = f_cat[args.sort_variable][:].argsort()
    if args.sort_order == 'descending':
        sorting = sorting[::-1]
    event_idx = sorting[0:events_to_read]
    stat = stat[event_idx]

    # Save the time / trigger / template ids for the events, in sorted order
    times = {}
    tids = {}
    ifo_list = f.attrs['ifos'].split(' ')
    for ifo in ifo_list:
        times[ifo] = f_cat[ifo]['time'][:][event_idx]
        tids[ifo] = f_cat[ifo]['trigger_id'][:][event_idx]
    bank_ids = f_cat['template_id'][:][event_idx]

# Template bank handle, passed to get_single_template_params for each event.
bank_data = h5py.File(args.bank_file, 'r')

# Walk the sorted events until num_events of them have been followed up or
# the events that were read run out.
event_times = {}    # ifo -> integer trigger times already followed up
skipped_data = []   # (ifo, time) of duplicates skipped since the last event
event_count = 0
is_background = 'background' in args.analysis_category
for curr_idx in range(events_to_read):
    if event_count >= num_events:
        break

    files = wf.FileList([])
    duplicate = False

    # Times and trig ids for trigger timeseries plot
    ifo_times_strings = []
    ifo_tids_strings = []

    # Mean time to use as reference for zerolag
    mtime = coinc.mean_if_greater_than_zero(
        [ifo_event_times[curr_idx] for ifo_event_times in times.values()]
    )[0]

    for ifo in times:
        trig_time = times[ifo][curr_idx]
        in_coinc = trig_time != -1
        if in_coinc:
            # Do plot special trigger
            ifo_tids_strings.append('%s:%s' % (ifo, tids[ifo][curr_idx]))
            plottime = trig_time
        else:
            # Ifo is not in coinc: use mean time instead and don't plot
            # special trigger
            plottime = mtime
        ifo_times_strings.append('%s:%s' % (ifo, plottime))

        # For background do not want to follow up 10 background coincs with the
        # same event in ifo 1 and different events in ifo 2, so record times
        seen_times = event_times.setdefault(ifo, [])
        # Only do this for background coincident triggers & not sentinel time -1
        if is_background and in_coinc and int(trig_time) in seen_times:
            skipped_data.append((ifo, int(trig_time)))
            duplicate = True
            break

    if duplicate:
        continue

    # A new event: remember its time in every ifo for later duplicate checks
    for ifo, ifo_event_times in times.items():
        event_times[ifo].append(int(ifo_event_times[curr_idx]))

    ifo_times = ' '.join(ifo_times_strings)
    ifo_tids = ' '.join(ifo_tids_strings)

    event_count += 1
    event_tag = str(event_count)

    # Report any duplicates skipped on the way to this event, then reset
    if skipped_data:
        layouts.append(
            mini.make_skipped_html(
                workflow,
                skipped_data,
                args.output_dir,
                tags=[f'SKIP_{event_count}'],
            )
        )
        skipped_data = []

    bank_id = bank_ids[curr_idx]

    layouts.append(
        mini.make_coinc_info(
            workflow,
            single_triggers,
            tmpltbank_file,
            coinc_file,
            args.output_dir,
            n_loudest=curr_idx,
            sort_order=args.sort_order,
            sort_var=args.sort_variable,
            tags=args.tags + [event_tag],
        )
    )
    files += mini.make_trigger_timeseries(
        workflow,
        single_triggers,
        ifo_times,
        args.output_dir,
        special_tids=ifo_tids,
        tags=args.tags + [event_tag],
    )

    params = mini.get_single_template_params(
        curr_idx, times, bank_data, bank_id, fsdt, tids
    )

    _, sngl_tmplt_plots = mini.make_single_template_plots(
        workflow,
        insp_segs,
        args.inspiral_data_read_name,
        args.inspiral_data_analyzed_name,
        params,
        args.output_dir,
        data_segments=insp_data_seglists,
        tags=args.tags + [event_tag],
    )
    files += sngl_tmplt_plots

    for single in single_triggers:
        sngl_time = times[single.ifo][curr_idx]
        if sngl_time == -1:
            # If this detector did not trigger, still make the plot, but use
            # the average time of detectors which did trigger
            sngl_time = coinc.mean_if_greater_than_zero(
                [times[sngl.ifo][curr_idx] for sngl in single_triggers]
            )[0]

        # Singles plots are only made when the time lies inside one of this
        # detector's analysed segments
        if any(sngl_time in seg for seg in insp_analysed_seglists[single.ifo]):
            files += mini.make_singles_timefreq(
                workflow,
                single,
                tmpltbank_file,
                sngl_time,
                args.output_dir,
                data_segments=insp_data_seglists[single.ifo],
                tags=args.tags + [event_tag],
            )
            files += mini.make_qscan_plot(
                workflow,
                single.ifo,
                sngl_time,
                args.output_dir,
                data_segments=insp_data_seglists[single.ifo],
                tags=args.tags + [event_tag],
            )
        else:
            logging.info(f'Trigger time {sngl_time} is not valid in '
                         f'{single.ifo}, skipping singles plots')

    # Add this event's plot files to the layout, grouped two at a time
    layouts.extend(layout.grouper(files, 2))

workflow.save()
layout.two_column_layout(args.output_dir, layouts)
