#!/usr/bin/env python
"""
This program combines the coincident trigger files generated by
pycbc_coinc_findtrigs to generate a mapping between SNR and FAP, and also
produces the combined foreground and background triggers.
"""
import sys
import argparse, h5py, logging, itertools, numpy
import lal
from pycbc.events import veto, coinc
import pycbc.version, pycbc.pnutils, pycbc.io

def sec_to_year(sec):
    """Convert a duration in seconds to years.

    The value is divided by ``lal.YRJUL_SI`` (LAL's Julian-year constant,
    expressed in seconds).  Works on anything that supports division by a
    float, so both scalars and numpy arrays can be passed.
    """
    seconds_per_year = lal.YRJUL_SI
    return sec / seconds_per_year

class fw(object):
    """Thin write-only wrapper around an HDF5 file.

    The file ``name`` is opened with h5py in ``'w'`` mode.  Assigning
    ``obj[path] = data`` creates a dataset at ``path`` holding ``data``.

    Attributes
    ----------
    f : the underlying ``h5py.File`` handle
    attrs : the attribute mapping of the file's root (``f.attrs``)
    """
    def __init__(self, name):
        handle = h5py.File(name, 'w')
        self.f = handle
        self.attrs = handle.attrs

    def __setitem__(self, name, data):
        # Every dataset is stored gzip-compressed at level 9 with the
        # shuffle filter switched on.
        storage_options = dict(compression='gzip', compression_opts=9,
                               shuffle=True)
        self.f.create_dataset(name, data=data, **storage_options)

# Command-line interface.  The input coinc files and the output path are
# mandatory: the script cannot do anything useful without them, so argparse
# is told to reject the invocation up front with a usage message instead of
# letting a None value cause an obscure failure further down.
parser = argparse.ArgumentParser()
# General required options
parser.add_argument('--version', action='version',
         version=pycbc.version.git_verbose_msg)
parser.add_argument('--coinc-files', nargs='+', required=True,
         help='List of coincidence files used to calculate the FAP, FAR, etc.')
parser.add_argument('--verbose', action='count',
         help='Verbosity count handed to pycbc.init_logging; repeat the flag to raise it')
parser.add_argument('--cluster-window', type=float, default=10,
         help='Length of time window in seconds to cluster coinc events, [default=10s]')
parser.add_argument('--veto-window', type=float, default=.1,
         help='Time around each zerolag trigger to window out, [default=.1s]')
parser.add_argument('--output-file', required=True,
         help='Path of the HDF5 file the results are written to')
args = parser.parse_args()
# With action='count' and no default, args.verbose is None when the flag is
# not given; that value is passed through to pycbc.init_logging unchanged.
pycbc.init_logging(args.verbose)

# ---- Load the coincs and derive the veto windows around the foreground ----
logging.info("Loading coinc triggers")
d = pycbc.io.StatmapData(files=args.coinc_files)
logging.info("We have %s triggers", len(d.stat))

# Entries with timeslide id 0 are the ones treated as foreground (zerolag).
fore_locs = d.timeslide_id == 0

# One window per foreground coinc: centred on the mean of its two trigger
# times and extending args.veto_window seconds to either side.
half_width = args.veto_window
fore_mid_time = (d.time1[fore_locs] + d.time2[fore_locs]) / 2.0
veto_start = fore_mid_time - half_width
veto_end = fore_mid_time + half_width

# abs() of the coalesced segment list; presumably the total time covered by
# the (merged) veto windows -- confirm against the segment-list type returned
# by veto.start_end_to_segments.
veto_segments = veto.start_end_to_segments(veto_start, veto_end)
veto_time = abs(veto_segments.coalesce())

# Build the "exclusive" trigger set e in two passes: first drop the coincs
# whose time1 is flagged against the veto windows, then, from what is left,
# those whose time2 is flagged.
flagged_by_time1 = veto.indices_within_times(d.time1, veto_start, veto_end)
e = d.remove(flagged_by_time1)

flagged_by_time2 = veto.indices_within_times(e.time2, veto_start, veto_end)
e = e.remove(flagged_by_time2)

# ---- Cluster both trigger sets with the same window ----
cluster_window = args.cluster_window

logging.info("Clustering coinc triggers (inclusive of zerolag)")
d = d.cluster(cluster_window)
# d now refers to the clustered set, so the foreground mask is rebuilt to
# match it.
fore_locs = d.timeslide_id == 0
logging.info("%s clustered foreground triggers", fore_locs.sum())

logging.info("Clustering coinc triggers (exclusive of zerolag)")
e = e.cluster(cluster_window)

# ---- Open the output file and write the foreground section ----
logging.info("Dumping foreground triggers")
f = fw(args.output_file)

# Metadata carried over unchanged from the input.
for attr_name in ('detector_1', 'detector_2', 'timeslide_interval'):
    f.attrs[attr_name] = d.attrs[attr_name]

# Copy over the segment for coincs and singles
for seg_name in d.seg.keys():
    seg = d.seg[seg_name]
    f['segments/%s/start' % seg_name] = seg['start'][:]
    f['segments/%s/end' % seg_name] = seg['end'][:]

have_foreground = fore_locs.sum() > 0
if not have_foreground:
    # No foreground survived: still create every dataset so that later
    # readers of the file find the expected layout.  The veto segment gets
    # a single placeholder entry and each foreground column is an empty
    # array of the matching dtype.
    f['segments/foreground_veto/start'] = numpy.array([0])
    f['segments/foreground_veto/end'] = numpy.array([0])
    for k in d.data:
        f['foreground/' + k] = numpy.array([], dtype=d.data[k].dtype)
else:
    # Record the veto windows and the foreground rows of every data column.
    f['segments/foreground_veto/start'] = veto_start
    f['segments/foreground_veto/end'] = veto_end
    for k in d.data:
        f['foreground/' + k] = d.data[k][fore_locs]

# ---- Write the background sections (or stop if there is no background) ----
# Everything with a non-zero timeslide id is treated as background.
back_locs = d.timeslide_id != 0

if back_locs.sum() == 0:
    # logging.warn is a deprecated alias; logging.warning is the supported
    # spelling and emits the same record.
    logging.warning("There were no background events, so we could not assign "
                    "any statistic values")
    # Without background nothing further is computed; sys.exit() with no
    # argument ends the script with a success status.
    sys.exit()

# Background rows taken from d, the set that was clustered together with
# the zerolag triggers.
logging.info("Dumping background triggers (inclusive of zerolag)")
for k in d.data:
    f['background/' + k] = d.data[k][back_locs]

# Every row of e, the set clustered after removing triggers around the
# zerolag events.
logging.info("Dumping background triggers (exclusive of zerolag)")
for k in e.data:
    f['background_exc/' + k] = e.data[k]

# ---- Livetime bookkeeping ----
slide_interval = d.attrs['timeslide_interval']
foreground_times = (d.attrs['foreground_time1'], d.attrs['foreground_time2'])
maxtime = max(foreground_times)
mintime = min(foreground_times)

# The "_exc" variants have the time covered by the foreground veto windows
# subtracted.
maxtime_exc = maxtime - veto_time
mintime_exc = mintime - veto_time

coinc_time = float(d.attrs['coinc_time'])
coinc_time_exc = coinc_time - veto_time

# Background time = (whole number of slide intervals that fit in the longer
# time) multiplied by the shorter time.
whole_intervals = int(maxtime / slide_interval)
background_time = whole_intervals * mintime

whole_intervals_exc = int(maxtime_exc / slide_interval)
background_time_exc = whole_intervals_exc * mintime_exc

# ---- Rank background and foreground against the background ----
logging.info("Making mapping from FAN to the combined statistic")
fore_stat = d.stat[fore_locs]

# calculate_n_louder returns a pair; the first item is used below for the
# background triggers and the second (fnlouder*) later for the foreground.
# Presumably these are counts of louder background events -- see
# pycbc.events.coinc for the exact definition.
back_cnum, fnlouder = coinc.calculate_n_louder(
    d.stat[back_locs], fore_stat, d.decimation_factor[back_locs])

back_cnum_exc, fnlouder_exc = coinc.calculate_n_louder(
    e.stat, fore_stat, e.decimation_factor)

# Inverse false-alarm rate of each background trigger, stored in years.
f['background/ifar'] = sec_to_year(background_time / (back_cnum + 1))
f['background_exc/ifar'] = sec_to_year(background_time_exc / (back_cnum_exc + 1))

# Record the times used, in the units they were computed in above.
time_attrs = (
    ('background_time', background_time),
    ('foreground_time', coinc_time),
    ('background_time_exc', background_time_exc),
    ('foreground_time_exc', coinc_time_exc),
)
for attr_name, attr_value in time_attrs:
    f.attrs[attr_name] = attr_value

# ---- Significance of the foreground triggers ----
logging.info("calculating ifar/fap values")

have_foreground = fore_locs.sum() > 0
if not have_foreground:
    # Nothing to rank: create the four datasets anyway, each one empty.
    for column in ('ifar', 'fap', 'ifar_exc', 'fap_exc'):
        f['foreground/' + column] = numpy.array([])
else:
    # Inclusive background: ifar is stored in years, while fap is
    # 1 - exp(-coinc_time / ifar) with ifar in the same units as coinc_time.
    ifar = background_time / (fnlouder + 1)
    fap = 1 - numpy.exp(-coinc_time / ifar)
    f['foreground/ifar'] = sec_to_year(ifar)
    f['foreground/fap'] = fap

    # Same quantities using the background that excludes the zerolag
    # veto windows.
    ifar_exc = background_time_exc / (fnlouder_exc + 1)
    fap_exc = 1 - numpy.exp(-coinc_time_exc / ifar_exc)
    f['foreground/ifar_exc'] = sec_to_year(ifar_exc)
    f['foreground/fap_exc'] = fap_exc

# Carry the optional 'name' attribute through when the input provides one.
if 'name' in d.attrs:
    f.attrs['name'] = d.attrs['name']

logging.info("Done")
