#!/bin/env python
# Copyright (C) 2015 Alexander Harvey Nitz
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
""" Followup foreground events
"""
import os, argparse, numpy, logging, h5py, copy
import pycbc.workflow as wf
from pycbc.types import MultiDetOptionAction
from pycbc.results import layout
from pycbc.detector import Detector
import pycbc.workflow.minifollowups as mini
import pycbc.workflow.pegasus_workflow as wdax
import pycbc.version, pycbc.pnutils
from pycbc.io.hdf import SingleDetTriggers

def to_file(path, ifo=None):
    """Wrap a file on disk as a ``wdax.File`` object.

    The returned object is named after the basename of ``path``, has its
    ``ifo`` attribute set to ``ifo`` and has the absolute form of ``path``
    registered through ``PFN(..., 'local')``.
    """
    workflow_file = wdax.File(os.path.basename(path))
    workflow_file.ifo = ifo
    workflow_file.PFN(os.path.abspath(path), 'local')
    return workflow_file

# Command line interface. The module docstring (minus its leading space) is
# used as the program description.
parser = argparse.ArgumentParser(description=__doc__[1:])
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg) 
parser.add_argument('--workflow-name', default='my_unamed_run')
parser.add_argument("-d", "--output-dir", default=None,
                    help="Path to output directory.")
parser.add_argument('--bank-file',
                    help="HDF format template bank file")
parser.add_argument('--injection-file',
                    help="HDF format injection results file")
parser.add_argument('--injection-xml-file',
                    help="XML format injection file")
parser.add_argument('--single-detector-triggers', nargs='+', action=MultiDetOptionAction,
                    help="HDF format merged single detector trigger files")
parser.add_argument('--inspiral-segments',
                    help="xml segment files containing the inspiral analysis times")
parser.add_argument('--inspiral-data-read-name',
                    help="Name of inspiral segmentlist containing data read in "
                         "by each analysis job.")
parser.add_argument('--inspiral-data-analyzed-name',
                    help="Name of inspiral segmentlist containing data "
                         "analyzed by each analysis job.")
# The window is compared against differences of trigger end times and its
# default is fractional, so it has to be parsed as a float. With type=int a
# value such as "0.5" given on the command line would be rejected.
parser.add_argument('--inj-window', type=float, default=0.5,
                    help="Time window in which to look for injection triggers")
parser.add_argument('--output-map')
parser.add_argument('--output-file')
parser.add_argument('--tags', nargs='+', default=[])
# Let the workflow module add its own group of options before parsing
wf.add_workflow_command_line_group(parser)
args = parser.parse_args()

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s:%(levelname)s : %(message)s')

workflow = wf.Workflow(args, args.workflow_name)

wf.makedir(args.output_dir)

# Rows of plots, accumulated over all followed-up injections and handed to
# the results page layout at the end of the script
layouts = []

# Wrap the input files named on the command line as workflow file objects
tmpltbank_file = to_file(args.bank_file)
injection_file = to_file(args.injection_file)
injection_xml_file = to_file(args.injection_xml_file)
insp_segs = to_file(args.inspiral_segments)

# One trigger file per detector, each remembering which detector it is for
single_triggers = [to_file(args.single_detector_triggers[ifo], ifo=ifo)
                   for ifo in args.single_detector_triggers]

# Open the injection results and load the 'missed/after_vetoes' dataset
f = h5py.File(args.injection_file, 'r')
missed = f['missed/after_vetoes'][:]

# How many injections to follow up, as set in the workflow configuration
num_events_opt = workflow.cp.get_opt_tags('workflow-injection_minifollowups',
                                          'num-events', '')
num_events = int(num_events_opt)

# Work out the order in which the missed injections are followed up.
# `sorting` holds indices into `missed`; the first entries are used first.
try:
    # Only the dataset lookups are guarded: h5py raises KeyError when a
    # dataset does not exist, and nothing else should trigger the fallback.
    optimal_snr = [f['injections/optimal_snr_1'][:][missed],
                   f['injections/optimal_snr_2'][:][missed]]
except KeyError:
    # Fall back to effective distance if optimal SNR not available.
    # Only the datasets of the first two detectors are read, so a file
    # lacking the effective distance of an unused detector still works.
    eff_dist_names = {'H1': 'eff_dist_h',
                      'L1': 'eff_dist_l',
                      'V1': 'eff_dist_v',
                      }
    eff_dist = [f['injections/%s' % eff_dist_names[trig.ifo]][:][missed]
                for trig in single_triggers[:2]]
    # Take the larger of the two effective distances for each injection
    dec_dist = numpy.maximum(eff_dist[0], eff_dist[1])

    mchirp, _ = pycbc.pnutils.mass1_mass2_to_mchirp_eta(
                                              f['injections/mass1'][:][missed],
                                              f['injections/mass2'][:][missed])

    dec_chirp_dist = pycbc.pnutils.chirp_distance(dec_dist, mchirp)

    # Smallest chirp distance first
    sorting = dec_chirp_dist.argsort()
else:
    # Take the smaller of the two optimal SNRs for each injection and put
    # the largest of those values first
    dec_snr = numpy.minimum(optimal_snr[0], optimal_snr[1])
    sorting = dec_snr.argsort()[::-1]

# Never try to follow up more injections than were missed
num_events = min(num_events, len(missed))

# Tag and description for each of the three template-plot jobs that use the
# injection's own parameters as the template. Apart from these two values the
# jobs are set up identically.
inj_template_runs = [
    ('INJ_PARAMS',
     'injection parameters as template, ' +
     'here the injection is made as normal'),
    ('INJ_PARAMS_INVERTED',
     'injection parameters as template, ' +
     'here the injection is made inverted'),
    ('INJ_PARAMS_NOINJ',
     'injection parameters, here no ' +
     'injection was actually performed'),
]

# loop over the missed injections to be followed up, in the order given by
# `sorting`; every pass appends that injection's plots to `layouts`
for num_event in range(num_events):
    files = wf.FileList([])
    event_tag = str(num_event)

    injection_index = missed[sorting[num_event]]
    time = f['injections/end_time'][injection_index]
    # These lookups depend only on the injection, not on the detector, so
    # they are done once rather than once per detector
    lon = f['injections/longitude'][injection_index]
    lat = f['injections/latitude'][injection_index]

    inj_params = {}
    for val in ['mass1', 'mass2', 'spin1z', 'spin2z', 'end_time']:
        inj_params[val] = f['injections/%s' % (val,)][injection_index]

    # Per-detector end time: the injection end time plus the delay returned
    # by time_delay_from_earth_center for this sky position
    ifo_times = ''
    for single in single_triggers:
        ifo = single.ifo
        det = Detector(ifo)
        ifo_time = time + det.time_delay_from_earth_center(lon, lat, time)
        ifo_times += ' %s:%s ' % (ifo, ifo_time) 
        inj_params[ifo + '_end_time'] = ifo_time

    # A fresh tag list is built for every call so that no job setup can be
    # affected by another one modifying the list it was given
    layouts += [(mini.make_inj_info(workflow, injection_file,
                                    injection_index, num_event,
                                    args.output_dir,
                                    tags=args.tags + [event_tag])[0],)]
    files += mini.make_trigger_timeseries(workflow, single_triggers,
                                          ifo_times, args.output_dir,
                                          tags=args.tags + [event_tag])

    for single in single_triggers:
        files += mini.make_singles_timefreq(workflow, single, tmpltbank_file,
                                            time - 10, time + 10,
                                            args.output_dir,
                                            tags=args.tags + [event_tag])

    for run_tag, run_description in inj_template_runs:
        files += mini.make_single_template_plots(workflow, insp_segs,
                            args.inspiral_data_read_name,
                            args.inspiral_data_analyzed_name, inj_params,
                            args.output_dir, inj_file=injection_xml_file,
                            tags=args.tags + [run_tag, event_tag],
                            params_str=run_description,
                            use_exact_inj_params=True)

    # For each detector, plot the template of the highest-SNR trigger found
    # within the injection window, if there is one
    for curr_ifo in args.single_detector_triggers:
        single_fname = args.single_detector_triggers[curr_ifo]
        hd_sngl = SingleDetTriggers(single_fname, args.bank_file, None, None,
                                    None, curr_ifo)
        end_times = hd_sngl.end_time
        # Use SNR here or NewSNR??
        snr = hd_sngl.snr
        lgc_mask = abs(end_times - inj_params['end_time']) < args.inj_window
        window_snr = snr[lgc_mask]

        # No trigger close enough to the injection in this detector
        if len(window_snr) == 0:
            continue

        # Restrict the trigger set to the single loudest trigger in the
        # window, so the attributes read below refer to that trigger
        snr_idx = hd_sngl.mask[lgc_mask][window_snr.argmax()]
        hd_sngl.mask = hd_sngl.mask[snr_idx]
        curr_params = copy.deepcopy(inj_params)
        curr_params['mass1'] = hd_sngl.mass1
        curr_params['mass2'] = hd_sngl.mass2
        curr_params['spin1z'] = hd_sngl.spin1z
        curr_params['spin2z'] = hd_sngl.spin2z
        # don't require precessing template info if not present
        try:
            curr_params['spin1x'] = hd_sngl.spin1x
            curr_params['spin2x'] = hd_sngl.spin2x
            curr_params['spin1y'] = hd_sngl.spin1y
            curr_params['spin2y'] = hd_sngl.spin2y
            curr_params['inclination'] = hd_sngl.inclination
        except KeyError:
            pass
        try:
            # Only present for precessing search
            curr_params['u_vals'] = hd_sngl.u_vals
        except Exception:
            # NOTE(review): the exception raised for a missing 'u_vals' is
            # not visible from this script, so this stays broad. Unlike a
            # bare except it lets KeyboardInterrupt and SystemExit through.
            pass

        curr_tags = ['TMPLT_PARAMS_%s' % (curr_ifo,), event_tag]
        files += mini.make_single_template_plots(workflow, insp_segs,
                                args.inspiral_data_read_name,
                                args.inspiral_data_analyzed_name, curr_params,
                                args.output_dir, inj_file=injection_xml_file,
                                tags=args.tags + curr_tags,
                                params_str='loudest template in %s' % curr_ifo)

    # Lay this injection's plots out in rows of two
    layouts += list(layout.grouper(files, 2))

# Write the workflow out, then build the two-column results layout from the
# rows collected for every followed-up injection
workflow.save(output_map=args.output_map, filename=args.output_file)
layout.two_column_layout(args.output_dir, layouts)
