#!/home/travis/build/pycbc-build/environment/bin/python
"""
The program combines coincident output files generated
by pycbc_coinc_findtrigs to generate a mapping between SNR and FAP, along
with producing the combined foreground and background triggers. It can also
perform hierarchical removal of foreground triggers that are louder than all
of the background triggers. We use this to properly assess the FANs of any
other gravitational waves in the dataset.
"""
import argparse, copy, h5py, itertools
import lal, logging, numpy 
from pycbc.events import veto, coinc
import pycbc.version, pycbc.pnutils, pycbc.io
import sys

def sec_to_year(sec):
    """Convert a duration in seconds to years.

    The value is divided by ``lal.YRJUL_SI`` (LAL's length of a Julian year
    in SI seconds).  ``sec`` may be anything that supports division by a
    float, e.g. a plain number or a numpy array.
    """
    seconds_per_year = lal.YRJUL_SI
    return sec / seconds_per_year

class fw(object):
    """Small convenience wrapper around an ``h5py.File`` opened for writing.

    ``obj[name] = data`` creates a compressed dataset the first time ``name``
    is assigned and overwrites the dataset contents in place on every later
    assignment.  ``obj[name]`` returns whatever the underlying file returns,
    and the file attributes are available as ``obj.attrs``.
    """

    def __init__(self, name):
        """Open the file ``name`` in ``'w'`` mode and expose its attributes."""
        handle = h5py.File(name, 'w')
        self.f = handle
        self.attrs = handle.attrs

    def __setitem__(self, name, data):
        """Store ``data`` under ``name``.

        ``data`` must provide a ``shape`` attribute when the dataset does not
        exist yet, because that shape is used as the dataset's ``maxshape``.
        """
        if name in self.f:
            # The dataset already exists: replace its contents in place.
            self.f[name][:] = data
            return

        # First assignment: create the dataset with gzip level 9 and the
        # shuffle filter enabled.
        self.f.create_dataset(name, data=data, maxshape=data.shape,
                              compression="gzip", compression_opts=9,
                              shuffle=True)

    def __getitem__(self, *args):
        """Forward item lookup to the wrapped file object."""
        lookup = self.f.__getitem__
        return lookup(*args)

# Command-line interface.  The input coinc files and the output file are
# used unconditionally further down, so they are declared as required: a
# missing option is then reported by argparse right away instead of
# surfacing later as an obscure failure.
parser = argparse.ArgumentParser()
# General options
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--coinc-files', nargs='+', required=True,
                    help='List of coincidence files used to calculate the '
                         'FAP, FAR, etc.')
parser.add_argument('--verbose', action='count',
                    help='Verbosity level handed to the logging setup; '
                         'may be given more than once.')
parser.add_argument('--cluster-window', type=float, default=10,
                    help='Length of time window in seconds to cluster coinc '
                         'events, [default=10s]')
parser.add_argument('--veto-window', type=float, default=.1,
                    help='Time around each zerolag trigger to window out, '
                         '[default=.1s]')
parser.add_argument('--hierarchical-removal-window', type=float, default=1.0,
                    help='Time around each trigger to window out for a very '
                         'louder trigger in the hierarchical removal '
                         'procedure. [default=1.0s]')
parser.add_argument('--max-hier-removal', type=int, default=0,
                    help='Maximum amount of hierarchical removals to carry '
                         'out when hierarchical removal is desired. Choose '
                         '-1 for continuous hiearchical removal until no '
                         'foreground triggers are louder than the inclusive '
                         'background. Choose 0 to do no hierarchical '
                         'removals.  Choose 1 to do at most 1 hierarchical '
                         'removal, and so on. [default=0]')
parser.add_argument('--output-file', required=True,
                    help='Path of the HDF5 file the results are written to.')
args = parser.parse_args()
pycbc.init_logging(args.verbose)

logging.info("Loading coinc triggers")
all_trigs = pycbc.io.StatmapData(files=args.coinc_files)
logging.info("We have %s triggers", len(all_trigs.stat))

# Zerolag (foreground) coincs are the ones that have not been time-slid.
fore_locs = all_trigs.timeslide_id == 0

# Mean of the two per-detector trigger times of every foreground coinc.
ave_fore_time = (all_trigs.time1[fore_locs] + all_trigs.time2[fore_locs]) / 2.0

# A window of +/- veto_window seconds around each of those mean times.
remove_start_time = ave_fore_time - args.veto_window
remove_end_time = ave_fore_time + args.veto_window

# Total duration covered by the windows once overlapping ones are merged.
veto_segments = veto.start_end_to_segments(remove_start_time, remove_end_time)
veto_time = abs(veto_segments.coalesce())

# Exclusive background: drop every trigger whose detector-1 time falls in a
# window, then every remaining trigger whose detector-2 time does.
exc_zero_trigs = all_trigs.remove(
    veto.indices_within_times(all_trigs.time1,
                              remove_start_time, remove_end_time))
exc_zero_trigs = exc_zero_trigs.remove(
    veto.indices_within_times(exc_zero_trigs.time2,
                              remove_start_time, remove_end_time))

logging.info("Clustering coinc triggers (inclusive of zerolag)")
all_trigs = all_trigs.cluster(args.cluster_window)

# Recompute the foreground mask, since clustering changed the trigger set.
fore_locs = all_trigs.timeslide_id == 0
logging.info("%s clustered foreground triggers", fore_locs.sum())

logging.info("Clustering coinc triggers (exclusive of zerolag)")
exc_zero_trigs = exc_zero_trigs.cluster(args.cluster_window)

logging.info("Dumping foreground triggers")
# Open the output file and carry over the identifying attributes.
f = fw(args.output_file)
f.attrs['detector_1'] = all_trigs.attrs['detector_1']
f.attrs['detector_2'] = all_trigs.attrs['detector_2']
f.attrs['timeslide_interval'] = all_trigs.attrs['timeslide_interval']

# Copy over the segment for coincs and singles
for key in all_trigs.seg.keys():
    f['segments/%s/start' % key] = all_trigs.seg[key]['start'][:]
    f['segments/%s/end' % key] = all_trigs.seg[key]['end'][:]

if fore_locs.sum() > 0:
    # Record the veto windows placed around the foreground triggers and
    # every per-trigger column restricted to the foreground.
    f['segments/foreground_veto/start'] = remove_start_time
    f['segments/foreground_veto/end'] = remove_end_time
    for k in all_trigs.data:
        f['foreground/' + k] = all_trigs.data[k][fore_locs]
else:
    # Put SOMETHING in here to avoid failures later
    f['segments/foreground_veto/start'] = numpy.array([0])
    f['segments/foreground_veto/end'] = numpy.array([0])
    for k in all_trigs.data:
        f['foreground/' + k] = numpy.array([], dtype=all_trigs.data[k].dtype)

# Background triggers are all the ones with a non-zero timeslide id.
back_locs = all_trigs.timeslide_id != 0

if back_locs.sum() == 0:
    # Without background nothing further can be computed, so stop here.
    # logging.warning is used because logging.warn is a deprecated alias.
    logging.warning("There were no background events, so we could not assign "
                    "any statistic values")
    sys.exit()

logging.info("Dumping background triggers (inclusive of zerolag)")
for k in all_trigs.data:
    f['background/' + k] = all_trigs.data[k][back_locs]

logging.info("Dumping background triggers (exclusive of zerolag)")
for k in exc_zero_trigs.data:
    f['background_exc/' + k] = exc_zero_trigs.data[k]

# Longer and shorter of the two per-detector foreground times.
detector_times = (all_trigs.attrs['foreground_time1'],
                  all_trigs.attrs['foreground_time2'])
maxtime = max(detector_times)
mintime = min(detector_times)

# The same two quantities with the vetoed time taken out.
maxtime_exc = maxtime - veto_time
mintime_exc = mintime - veto_time

slide_interval = all_trigs.attrs['timeslide_interval']

# Inclusive background time and zerolag coincident time.
background_time = int(maxtime / slide_interval) * mintime
coinc_time = float(all_trigs.attrs['coinc_time'])

# Their counterparts exclusive of the vetoed time.
background_time_exc = int(maxtime_exc / slide_interval) * mintime_exc
coinc_time_exc = coinc_time - veto_time

logging.info("Making mapping from FAN to the combined statistic")

# Ranking statistic of the background and of the foreground.
back_stat = all_trigs.stat[back_locs]
fore_stat = all_trigs.stat[fore_locs]

# Louder-background counts from the inclusive background: one value per
# background trigger (back_cnum) and one per foreground trigger (fnlouder).
back_cnum, fnlouder = coinc.calculate_n_louder(
    back_stat, fore_stat, all_trigs.decimation_factor[back_locs])

# The same counts taken from the exclusive background.
back_cnum_exc, fnlouder_exc = coinc.calculate_n_louder(
    exc_zero_trigs.stat, fore_stat, exc_zero_trigs.decimation_factor)

f['background/ifar'] = sec_to_year(background_time / (back_cnum + 1))
f['background_exc/ifar'] = sec_to_year(background_time_exc / (back_cnum_exc + 1))

for attr_name, attr_value in (('background_time', background_time),
                              ('foreground_time', coinc_time),
                              ('background_time_exc', background_time_exc),
                              ('foreground_time_exc', coinc_time_exc)):
    f.attrs[attr_name] = attr_value

logging.info("calculating ifar/fap values")

if fore_locs.sum() > 0:
    # Inclusive background.
    ifar = background_time / (fnlouder + 1)
    fap = 1 - numpy.exp(- coinc_time / ifar)
    f['foreground/ifar'] = sec_to_year(ifar)
    f['foreground/fap'] = fap

    # Exclusive background.
    ifar_exc = background_time_exc / (fnlouder_exc + 1)
    fap_exc = 1 - numpy.exp(- coinc_time_exc / ifar_exc)
    f['foreground/ifar_exc'] = sec_to_year(ifar_exc)
    f['foreground/fap_exc'] = fap_exc
else:
    # No foreground triggers: write empty datasets under the same names.
    for column in ('ifar', 'fap', 'ifar_exc', 'fap_exc'):
        f['foreground/' + column] = numpy.array([])

if 'name' in all_trigs.attrs:
    f.attrs['name'] = all_trigs.attrs['name']

# Incorporate hierarchical removal for any other loud triggers: as long as
# some foreground trigger has no louder inclusive-background trigger, take
# out the loudest foreground trigger together with everything close to it in
# time, then re-cluster and re-evaluate the remaining foreground.
logging.info("Beginning hierarchical removal of foreground triggers.")

# Number of hierarchical removals carried out so far.
h_iterations = 0

# Step 1: Remember the foreground ranking statistic as it was before any
#         removal. It is used to map the triggers that survive a removal
#         back to their rows in the 'foreground/...' datasets. fore_stat is
#         rebound (never modified in place) below, so this array stays
#         untouched.
orig_fore_stat = fore_stat

# Step 2 : Loop until we don't have to hierarchically remove anymore. This
#          will happen when fnlouder has no elements that equal 0.

while numpy.any(fnlouder == 0):
    # If the user wants to stop doing hierarchical removals after a set
    # number of iterations then break when that happens.
    if h_iterations == args.max_hier_removal:
        break

    # Write foreground trigger info before hierarchical removals for
    # downstream codes.
    if h_iterations == 0:
        f['background_h%s/stat' % h_iterations] = back_stat
        f['background_h%s/ifar' % h_iterations] = sec_to_year(background_time / (back_cnum + 1))
        f['foreground_h%s/stat' % h_iterations] = fore_stat
        f['foreground_h%s/ifar' % h_iterations] = sec_to_year(ifar)
        f['foreground_h%s/fap' % h_iterations] = fap

    # Add the iteration number of hierarchical removals done.
    h_iterations += 1

    # Among foreground triggers, find the one with the largest ranking
    # statistic and mark it for removal.
    max_stat_idx = fore_stat.argmax()

    # Step 3: Remove that trigger from the list of zerolag triggers

    # fore_stat is all_trigs.stat[fore_locs], so the position of the loudest
    # foreground trigger inside all_trigs is the max_stat_idx-th True entry
    # of fore_locs. Looking it up by position (instead of searching all
    # triggers for an equal statistic value) guarantees that a background
    # trigger sharing the same statistic value is never picked by mistake.
    rm_trig_idx = numpy.flatnonzero(fore_locs)[max_stat_idx]

    # Row of that trigger in the original foreground datasets.
    orig_fore_idx = numpy.where(orig_fore_stat == fore_stat[max_stat_idx])[0][0]

    # Store any foreground trigger's information that we want to
    # hierarchically remove.

    f['foreground/ifar'][orig_fore_idx] = sec_to_year(ifar[max_stat_idx])
    f['foreground/fap'][orig_fore_idx] = fap[max_stat_idx]

    logging.info("Removing foreground trigger that is louder than the inclusive background.")

    # Remove the foreground trigger and all of the background triggers that
    # are associated with it, i.e. every trigger whose time in either
    # detector lies within the removal window around its average time.

    ave_rm_time = (all_trigs.time1[rm_trig_idx] + all_trigs.time2[rm_trig_idx]) / 2.0

    ind_to_rm_ifo1 = veto.indices_within_times(all_trigs.time1,
                              [ave_rm_time - args.hierarchical_removal_window],
                              [ave_rm_time + args.hierarchical_removal_window])
    ind_to_rm_ifo2 = veto.indices_within_times(all_trigs.time2,
                              [ave_rm_time - args.hierarchical_removal_window],
                              [ave_rm_time + args.hierarchical_removal_window])

    indices_to_rm = numpy.concatenate([ind_to_rm_ifo1, ind_to_rm_ifo2])

    all_trigs = all_trigs.remove(indices_to_rm)

    fore_locs = all_trigs.timeslide_id == 0
    # The foreground trigger has been removed, continue with typical statmap operations.

    # Calculate the change to foreground trigger time and vetoed out time
    # NOTE(review): these five values are recomputed here but are not read
    # again anywhere in the visible code.
    fore_time1 = all_trigs.time1[fore_locs]
    fore_time2 = all_trigs.time2[fore_locs]

    ave_fore_time = (fore_time1 + fore_time2) / 2.0

    remove_start_time = ave_fore_time - args.veto_window
    remove_end_time = ave_fore_time + args.veto_window

    logging.info("We have %s triggers after hierarchical removal.", len(all_trigs.stat))

    # Step 4: Re cluster the triggers and calculate the inclusive ifar/fap
    logging.info("Clustering coinc triggers (inclusive of zerolag)")
    all_trigs = all_trigs.cluster(args.cluster_window)

    # Masks must be rebuilt because clustering changed the trigger set.
    fore_locs = all_trigs.timeslide_id == 0

    logging.info("%s clustered foreground triggers", fore_locs.sum())

    logging.info("%s hierarchically removed foreground trigger(s)", h_iterations)

    back_locs = all_trigs.timeslide_id != 0

    logging.info("Dumping foreground triggers")

    logging.info("Dumping background triggers (inclusive of zerolag)")
    for k in all_trigs.data:
        f['background_h%s/' % h_iterations + k] = all_trigs.data[k][back_locs]

    maxtime = max(all_trigs.attrs['foreground_time1'], all_trigs.attrs['foreground_time2'])
    mintime = min(all_trigs.attrs['foreground_time1'], all_trigs.attrs['foreground_time2'])

    background_time = int(maxtime / all_trigs.attrs['timeslide_interval']) * mintime
    coinc_time = float(all_trigs.attrs['coinc_time'])

    logging.info("Making mapping from FAN to the combined statistic")

    back_stat = all_trigs.stat[back_locs]
    fore_stat = all_trigs.stat[fore_locs]

    # Refresh the louder-background counts; fnlouder also drives the loop
    # condition for the next iteration.
    back_cnum, fnlouder = coinc.calculate_n_louder(back_stat, fore_stat,
                                                   all_trigs.decimation_factor[back_locs])

    logging.info("Calculating ifar/fap values")

    f['background_h%s/ifar' % h_iterations] = sec_to_year(background_time / (back_cnum + 1))
    f.attrs['background_time_h%s' % h_iterations] = background_time
    f.attrs['foreground_time_h%s' % h_iterations] = coinc_time

    if fore_locs.sum() > 0:
        # Write ranking statistic to file just for downstream plotting code
        f['foreground_h%s/stat' % h_iterations] = fore_stat

        ifar = background_time / (fnlouder + 1)
        fap = 1 - numpy.exp(- coinc_time / ifar)

        f['foreground_h%s/ifar' % h_iterations] = sec_to_year(ifar)
        f['foreground_h%s/fap' % h_iterations] = fap

        # Update ifar and fap for other foreground triggers: each surviving
        # trigger is matched to its original row through its statistic value.
        for i in range(len(ifar)):
            orig_fore_idx = numpy.where(orig_fore_stat == fore_stat[i])[0][0]
            f['foreground/ifar'][orig_fore_idx] = sec_to_year(ifar[i])
            f['foreground/fap'][orig_fore_idx] = fap[i]

# Write to file how many hierarchical removals were implemented.
f.attrs['hierarchical_removal_iterations'] = h_iterations

logging.info("Done")
