#!/home/travis/build/pycbc-build/environment/bin/python
# Copyright (C) 2015 Christopher M. Biwer
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

import argparse
import h5py
import numpy
import sys
import lal
import matplotlib as mpl; mpl.use('Agg')
import pylab
import pycbc.results
import pycbc.version
from pycbc_glue import segments

def calculate_time_slide_duration(ifo1_segments, ifo2_segments, offset=0):
    ''' Returns the amount of coincident time between two segmentlists
    after the second one has been slid by ``offset``.

    Parameters
    ----------
    ifo1_segments : segmentlist-like
        Segments of the detector that is held fixed. ``coalesce()`` is
        called on it.
    ifo2_segments : segmentlist-like
        Segments of the detector that is slid. ``coalesce()`` and
        ``shift(offset)`` are called on it, followed by ``shift(-offset)``
        once the overlap has been measured.
    offset : number, optional
        Amount added to every segment of ``ifo2_segments`` before the two
        lists are intersected. Defaults to 0 (no slide).

    Returns
    -------
    The ``abs()`` of the coalesced intersection of the two lists.
    '''

    # coalesce both lists and slide the second detector
    # NOTE(review): presumably coalesce() and shift() act in place and return
    # the list itself for the segments implementation used by this file --
    # confirm against the installed segments library
    fixed = ifo1_segments.coalesce()
    slid = ifo2_segments.coalesce().shift(offset)

    # determine how much coincident time is left after the slide
    coincident = fixed & slid
    duration = abs(coincident.coalesce())

    # unshift the shifted segmentlist so the caller gets it back unslid
    ifo2_segments.shift(-offset)

    return duration

# parse command line
# NOTE: adjacent string literals are joined with no separator, so every
# fragment of a multi-line description/help text must end with a space.
parser = argparse.ArgumentParser(usage='pycbc_page_ifar [--options]',
                    description='Plots a cumulative histogram of IFAR for '
                          'coincident foreground triggers and a subset of '
                          'the coincident time slide triggers.')
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-file', type=str, required=True,
                    help='Path to coincident trigger HDF file.')
parser.add_argument('--output-file', type=str, required=True,
                    help='Path to output plot.')
parser.add_argument('--decimation-factor', type=int, required=True,
                    help='Decimation factor used in background estimation. '
                          'Decimation factor means that every nth time '
                          'slide kept all of its coincident triggers in '
                          'the HDF file.')
parser.add_argument('--all-decimated-background-min-ifar', type=float,
                    default=0.01, metavar='VAL',
                    help='Threshold the IFARs of decimated background events '
                          'before combining them and rescaling to zerolag '
                          '(green line). Plotting all decimated background '
                          'can be memory intensive and is not necessary for '
                          'checking the background estimate at interesting '
                          'IFAR levels. Slide coincs with IFARs below VAL '
                          'will be cut, the green line will then be valid for '
                          'zerolag down to VAL*(zerolag coinc time/total '
                          'decimated slide time). Recommended value equal '
                          'to default value 0.01yr.')
parser.add_argument('--open-box', action='store_true', default=False,
                    help='Show the foreground triggers on the output plot. ')
opts = parser.parse_args()

# open the coincident trigger file read-only; the handle is kept open because
# later sections of the script still read datasets and attributes from it
fp = h5py.File(opts.trigger_file, 'r')

# foreground IFARs are only loaded for an open box; they are sorted ascending
# and paired with a descending count (N, N-1, ..., 1) so that each point is
# "number of triggers with IFAR at least this large"
if opts.open_box:
    fore_ifar = numpy.sort(fp['foreground/ifar'][:])
    fore_cumnum = numpy.arange(len(fore_ifar), 0, -1)

# expected cumulative number on a log-spaced IFAR grid from 1e-8 to 1e2:
# foreground_time / IFAR / lal.YRJUL_SI
# (presumably foreground_time is in seconds and IFAR in years -- TODO confirm)
expected_ifar = numpy.logspace(-8, 2, num=100)
expected_cumnum = fp.attrs['foreground_time'] / expected_ifar / lal.YRJUL_SI

# background triggers: the time slide each one belongs to and its IFAR
back_ifar = fp['background/ifar'][:]
back_tsid = fp['background/timeslide_id'][:]

# build a segmentlist per detector from the start/end datasets
# NOTE(review): the detector names are hard-coded to H1 and L1 here although
# the file also carries 'detector_1'/'detector_2' attributes -- confirm the
# two always agree
h1_group = fp['segments']['H1']
h1_segments = segments.segmentlist(
    [segments.segment(start, end)
     for start, end in zip(h1_group['start'][:], h1_group['end'][:])])
l1_group = fp['segments']['L1']
l1_segments = segments.segmentlist(
    [segments.segment(start, end)
     for start, end in zip(l1_group['start'][:], l1_group['end'][:])])

# make figure
fig = pylab.figure(1)

# work out which time slide IDs to loop over: the extreme slide numbers are
# taken from the span between one detector's earliest start and the other
# detector's latest end (and vice versa), in units of the slide interval,
# then floor-divided by the decimation factor
interval = fp.attrs['timeslide_interval']
ifo1, ifo2 = fp.attrs['detector_1'], fp.attrs['detector_2']
ifo1_group = fp['segments'][ifo1]
ifo2_group = fp['segments'][ifo2]
decimation = opts.decimation_factor

min_tsid = (ifo1_group['start'][:].min() - ifo2_group['end'][:].max()) / interval
min_tsid = numpy.rint(min_tsid) // decimation
max_tsid = (ifo1_group['end'][:].max() - ifo2_group['start'][:].min()) / interval
max_tsid = numpy.rint(max_tsid) // decimation

# only every decimation-th slide ID is visited
# NOTE(review): numpy.arange excludes its upper bound, so the slide at
# max_tsid itself is never visited -- confirm this is intended
tsids = numpy.arange(min_tsid, max_tsid, 1).astype(numpy.int64) * decimation

# running totals filled in by the loop over time slides
allbkg_ifars = []
allbkg_dur = 0
plotted_slide_count = 0
empty_slide_count = 0

# these values do not change between slides, so read them from the HDF file
# and the options once instead of on every iteration
foreground_time = fp.attrs['foreground_time']
slide_interval = fp.attrs['timeslide_interval']
min_bkg_ifar = opts.all_decimated_background_min_ifar

for tsid in tsids:
    # slide 0 has zero offset, i.e. it is not a time slide
    if tsid == 0:
        continue

    # calculate the amount of coincident time in this time slide and skip
    # slides that have none (this also avoids dividing by zero below)
    offset = tsid * slide_interval
    back_dur = calculate_time_slide_duration(h1_segments, l1_segments, offset=offset)
    if back_dur == 0:
        continue
    plotted_slide_count += 1
    allbkg_dur = allbkg_dur + back_dur

    # find all triggers in this time slide (selected once and reused)
    slide_ifar = back_ifar[back_tsid == tsid]

    # Limit how far back the decimated background gets plotted: only IFARs
    # above the threshold go into the combined (green) curve
    red_trigs = slide_ifar[slide_ifar > min_bkg_ifar]
    allbkg_ifars.extend(red_trigs)

    # apply the correction factor for this time slide to its IFAR
    # you need a correction factor because the analyzed time of the time slide
    # is not the same as the analyzed time of the foreground
    ts_ifar = slide_ifar * (foreground_time / back_dur)
    ts_ifar.sort()

    # get the cumulative number of triggers in this time slide
    ts_cumnum = numpy.arange(len(ts_ifar), 0, -1)

    # plot the time slide triggers: a line for several events, a single dot
    # for one event, and just a tally for slides with no events
    if len(ts_ifar) > 1:
        pylab.loglog(ts_ifar, ts_cumnum, color='gray', alpha=0.4)
    elif len(ts_ifar) == 1:
        pylab.plot(ts_ifar, ts_cumnum, color='gray', marker='.', alpha=0.4)
    else:
        empty_slide_count += 1

# combine the thresholded IFARs of every visited slide into one curve,
# rescaled by (foreground time / total slide time), and plot it in green
bkg_rescale = fp.attrs['foreground_time'] / allbkg_dur
allbkg_ifars = numpy.sort(numpy.array(allbkg_ifars) * bkg_rescale)
allbkg_cumnum = numpy.arange(len(allbkg_ifars), 0, -1)
pylab.loglog(allbkg_ifars, allbkg_cumnum,
             color='green', linewidth=1.5, label="All decimated background")

# plot the expected background
pylab.loglog(expected_ifar, expected_cumnum, linestyle='--', linewidth=2,
             color='black', label='Expected Background')

# plot the counting-error bands: +/- 1 and +/- 2 times sqrt(N) around the
# expected cumulative number, the wider band drawn more transparent
counting_error = numpy.sqrt(expected_cumnum)
for n_sigma, band_alpha, band_label in ((1, 0.4, '$N^{1/2}$ Errors'),
                                        (2, 0.2, '$2N^{1/2}$ Errors')):
    error_plus = expected_cumnum + n_sigma * counting_error
    error_minus = expected_cumnum - n_sigma * counting_error
    # the axes are logarithmic, so clip non-positive lower bounds to a small
    # positive number
    error_minus = numpy.where(error_minus<=0, 1e-5, error_minus)
    xs, ys = pylab.poly_between(expected_ifar, error_minus, error_plus)
    pylab.fill(xs, ys, facecolor='y', alpha=band_alpha, label=band_label)

# plot the foreground triggers (open box only)
if opts.open_box:
    pylab.loglog(fore_ifar, fore_cumnum, linestyle='None', color='blue',
                 marker='^', label='Foreground')

# format plot: pick the curve that the axis limits are scaled around
if opts.open_box and len(fore_cumnum) > 100:
    # More than 100 foreground triggers: scale the plot around the foreground
    scale_ifars, scale_count = fore_ifar, len(fore_cumnum)
elif len(allbkg_cumnum) > 0:
    # 100 or fewer foreground triggers (or this is closed box): scale the
    # plot around the cumulative background (the green line).
    scale_ifars, scale_count = allbkg_ifars, len(allbkg_cumnum)
else:
    # nothing to scale around; leave the automatic limits alone
    scale_ifars, scale_count = None, 0

if scale_ifars is not None:
    pylab.ylim(0.8, 1.1 * scale_count)
    # single argument: only the lower x limit is set
    pylab.xlim(0.9 * min(scale_ifars))

pylab.grid()
pylab.legend()
pylab.ylabel('Cumulative Number')
pylab.xlabel('Inverse False Alarm Rate (yr)')

# save
# the caption has two integer placeholders, in order: the number of time
# slides that had coincident time, and how many of those had no events
caption = ('This is a cumulative histogram of triggers. The blue triangles represent '
           'coincident foreground triggers. The dashed line represents the expected background '
           'given the analysis time. The shaded regions represent '
           'counting errors. The gray lines are time slides treated as zero lag, here there are %d time '
           'slides plotted. Gray dots are time slides with only one event. '
           '%d of the plotted slides have zero events. '
           'The green line represents all decimated time slides rescaled to the analysis time.'
           ) % (plotted_slide_count, empty_slide_count)
pycbc.results.save_fig_with_metadata(fig, opts.output_file,
     title='Cumulative Number vs. IFAR',
     caption=caption,
     cmd=' '.join(sys.argv))

# everything needed from the trigger file has been read; release the handle
fp.close()
