#!/usr/bin/python
""" Plot the rate of background and foreground coincident events as a
histogram over the ranking statistic
"""
import argparse, h5py, numpy, logging, sys, lal
import matplotlib
matplotlib.use('Agg')
import pylab, pycbc.results, pycbc.version
from scipy.special import erf, erfinv

pylab.rc('text', usetex=True)

def sigma_from_p(p):
    """Convert a one-sided Gaussian tail probability into a significance.

    Parameters
    ----------
    p : float or array-like
        Upper-tail probability, in [0, 1].

    Returns
    -------
    float or numpy.ndarray
        The number of standard deviations ``sigma`` for which a standard
        normal variable exceeds ``sigma`` with probability ``p``.
    """
    # Local import keeps this function self-contained; the module header only
    # brings in erf and erfinv.
    from scipy.special import ndtri

    # Mathematically identical to -erfinv(2 * p - 1) * sqrt(2), but ndtri is
    # evaluated directly on p, so a very small p no longer rounds the erfinv
    # argument to -1 and yields an infinite significance.
    return -ndtri(p)

def p_from_sigma(sig):
    """Return the standard normal cumulative probability at ``sig``.

    Parameters
    ----------
    sig : float or array-like
        Significance in units of standard deviations.

    Returns
    -------
    float or numpy.ndarray
        ``1 - (1 - erf(sig / sqrt(2))) / 2``, i.e. 0.5 at ``sig = 0`` and
        approaching 1 for large positive ``sig``.
    """
    # Probability of lying outside +/- sig, both tails together
    two_sided_tail = 1 - erf(sig / 2 ** 0.5)
    # Remove a single tail from unity
    return 1 - two_sided_tail / 2
    
# Command line interface. The input and output paths are mandatory: the rest
# of the script opens --trigger-file and saves the figure to --output-file
# unconditionally, so failing here gives a clear usage error instead.
parser = argparse.ArgumentParser()
# General required options
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
parser.add_argument('--trigger-file', required=True,
      help="HDF5 file holding the foreground and background triggers")
parser.add_argument('--verbose', action='count',
      help="Enable INFO level logging")
parser.add_argument('--output-file', required=True,
      help="Path of the figure to write")
parser.add_argument('--bin-size', type=float, default=0.1,
      help="Width of the histogram bins in ranking statistic")
parser.add_argument('--x-min', type=float, default=8.0,
      help="Lower limit of the x axis")
parser.add_argument('--trials-factor', type=int, default=1,
      help="Trials factor used when placing the sigma bands")
# NOTE(review): --cumulative is parsed but not referenced by the rest of the
# script as visible here; confirm whether it is still needed.
parser.add_argument('--cumulative', action='store_true')
parser.add_argument('--closed-box', action='store_true',
      help="Make a closed box version that excludes foreground triggers")
args = parser.parse_args()

# Logging is only configured when --verbose was given at least once;
# otherwise the logging.info calls below stay silent.
if args.verbose:
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s : %(message)s')
    
logging.info('Read in the data')
f = h5py.File(args.trigger_file, 'r')

# fstat holds the sorted foreground statistic values, or None when there is
# nothing to draw (closed box, dataset missing, or dataset empty).
fstat = None
if not args.closed_box:
    try:
        fstat = f['foreground/stat'][:]
    except KeyError:
        # The file has no foreground statistic dataset
        fstat = None
    else:
        # Sorted order is needed later for the searchsorted based binning
        fstat.sort()
        if len(fstat) == 0:
            fstat = None


# Background that includes coincidences involving foreground (zerolag)
# triggers, as the log message below describes it.
dec = f['background/decimation_factor'][:]
bstat = f['background/stat'][:]
# fap = 1 - exp(-foreground_time / ifar / YRJUL_SI); presumably the false
# alarm probability over the foreground time. -expm1(-x) is the same
# quantity but keeps its precision when x is tiny, where 1 - exp(-x) would
# round to zero.
fap = -numpy.expm1(- f.attrs['foreground_time'] / f['background/ifar'][:] / lal.YRJUL_SI)

# Order everything by increasing ranking statistic
s = bstat.argsort()
dec, bstat, fap = dec[s], bstat[s], fap[s]
logging.info('Found %s background (inclusive zerolag) triggers' % len(bstat))

# Background that excludes the zerolag triggers, ordered the same way
dec_exc = f['background_exc/decimation_factor'][:]
bstat_exc = f['background_exc/stat'][:]
s = bstat_exc.argsort()
dec_exc, bstat_exc = dec_exc[s], bstat_exc[s]


logging.info('Found %s background (exclusive zerolag) triggers' % len(bstat_exc))

fig = pylab.figure()

# Pick the statistic range that the histogram bins have to span
if fstat is None:
    # No foreground to show: span whichever background set gets plotted
    reference = bstat_exc if args.closed_box else bstat
    stat_lo, stat_hi = reference.min(), reference.max()
else:
    # Start at the quietest foreground event and reach the loudest event of
    # either the foreground or the full background
    stat_lo = fstat.min()
    stat_hi = max(fstat.max(), bstat.max())

# Bin edges spaced by --bin-size; the extra bin_size on the stop value is
# there so the loudest value is covered by the edges
bins = numpy.arange(stat_lo, stat_hi + args.bin_size, args.bin_size)

# Every curve is weighted into a rate per year, matching the y axis label

# plot background minus foreground
exc_weights = dec_exc / f.attrs['background_time_exc'] * lal.YRJUL_SI
pylab.hist(bstat_exc, bins=bins, weights=exc_weights,
           histtype='step', linewidth=2, color='grey', log=True,
           label='Background Uncorrelated with Foreground')

# plot full background
if not args.closed_box:
    full_weights = dec / f.attrs['background_time'] * lal.YRJUL_SI
    pylab.hist(bstat, bins=bins, weights=full_weights,
               histtype='step', linewidth=2, color='black', log=True,
               label='Full Background')

# plot the foreground as one marker per bin
if fstat is not None and not args.closed_box:
    lower_edges, upper_edges = bins[:-1], bins[1:]
    # fstat is sorted, so the difference of insertion points is the number
    # of foreground events falling inside each bin
    events_in_bin = (numpy.searchsorted(fstat, upper_edges)
                     - numpy.searchsorted(fstat, lower_edges))
    count = events_in_bin / f.attrs['foreground_time'] * lal.YRJUL_SI
    pylab.errorbar(lower_edges + args.bin_size / 2, count,
                   xerr=args.bin_size/2, label='Foreground',
                   fmt='o', ms=1, mec='none', capthick=0, elinewidth=4,
                   color='#ff6600')

pylab.xlabel('Weighted Network SNR, $\\rho_c$ (Bin Size = %.2f)' % args.bin_size)
pylab.ylabel('Trigger Rate (yr$^{-1})$')
pylab.xlim(xmin=args.x_min)
# Floor of the y axis: half of one event over the exclusive background time
pylab.ylim(ymin=0.5 / f.attrs['background_time_exc'] * lal.YRJUL_SI)
pylab.grid()
leg = pylab.legend(loc='upper center', fontsize=9)

# Highest significance the background can support: the smallest background
# fap, multiplied by the trials factor, expressed in sigma.
end = sigma_from_p(fap.min() * args.trials_factor)

if not args.closed_box:
    # fap and bstat were ordered by increasing bstat earlier in the script;
    # the reversed views are what searchsorted is run on (presumably fap
    # falls as the statistic rises, making the reversed fap ascending).
    fap_rev = fap[::-1]
    bstat_rev = bstat[::-1]

    # Shade one band between each pair of consecutive significance levels
    sigmas = [1, 2, 3, 4, end]
    for sig, next_sig in zip(sigmas[:-1], sigmas[1:]):
        # p_from_sigma returns the cumulative probability, so the tail
        # probability is 1 - p_from_sigma(...). The trials factor divides
        # that tail, which is the inverse of how `end` is computed above.
        # (The parentheses matter: without them the cumulative probability
        # would be divided instead.)
        p1 = (1 - p_from_sigma(sig)) / args.trials_factor
        p2 = (1 - p_from_sigma(next_sig)) / args.trials_factor

        x1 = numpy.searchsorted(fap_rev, p1)
        x2 = numpy.searchsorted(fap_rev, p2)

        # Both levels land on the same trigger: nothing to shade
        if x1 == x2:
            continue

        ymin, ymax = pylab.gca().get_ylim()
        try:
            x = [bstat_rev[x1], bstat_rev[x2]]
        except IndexError:
            # The level lies beyond the available background triggers
            break
        pylab.fill_between(x, ymin, ymax, color=pylab.cm.Blues(next_sig / 8.0), zorder=-1)

        # The final level is generally not an integer; show one decimal
        if next_sig == end:
            next_sig = '%.1f' % next_sig

        pylab.text(bstat_rev[x2] - .1, ymax, r"$%s \sigma$" % next_sig, fontsize=10)

# Add a right-hand axis showing the same limits scaled from a rate per year
# to a count over the foreground time
ax1 = pylab.gca()
ax2 = ax1.twinx()
fac = f.attrs['foreground_time'] / lal.YRJUL_SI
rate_lo, rate_hi = ax1.get_ylim()

ax2.set_ylim(ymin=rate_lo * fac, ymax=rate_hi * fac)
ax2.set_yscale('log')
ax2.set_ylabel('Number per Experiment')

# Prefix the title with the bin name when the input file provides one
if 'name' in f.attrs:
    plot_title = "%s bin, Count vs Rank" % f.attrs['name']
else:
    plot_title = "Count vs Rank"

pycbc.results.save_fig_with_metadata(
    fig, args.output_file,
    title=plot_title,
    caption="Histogram of the FAR vs the ranking statistic in the search.",
    cmd=' '.join(sys.argv))
             
