#!/usr/bin/python
""" Plot the ranking statistic of foreground and background coincident
triggers against their false alarm rate (or cumulative rate).
"""
import argparse, h5py, numpy, logging, sys, lal
import matplotlib
matplotlib.use('Agg')
import pylab, pycbc.results, pycbc.version
from scipy.special import erfc, erfinv

def sigma_from_p(p):
    """Convert a one-sided Gaussian tail probability into a number of sigma.

    Returns the value ``sig`` satisfying ``erfc(sig / sqrt(2)) / 2 == p``,
    i.e. the inverse of ``p_from_sigma``.

    Parameters
    ----------
    p : float or numpy array
        Tail probability (0 <= p <= 1).

    Returns
    -------
    float or numpy array
        The corresponding significance; ``inf`` for ``p == 0`` and ``-inf``
        for ``p == 1``.
    """
    # Imported locally so the function is self-contained.
    from scipy.special import erfcinv
    # erfcinv(2p) is mathematically identical to -erfinv(1 - 2(1 - p)) but
    # avoids forming 1 - p, which rounds to 1 for very small p and made the
    # previous expression return inf for highly significant events.
    return erfcinv(2 * p) * 2**0.5

def p_from_sigma(sig):
    """Return the one-sided Gaussian tail probability for ``sig`` sigma.

    Evaluates ``erfc(sig / sqrt(2)) / 2``; accepts scalars or numpy arrays.
    """
    scaled = sig / numpy.sqrt(2)
    return erfc(scaled) / 2.

def p_from_far(far, livetime):
    """Convert a false alarm rate into a false alarm probability.

    Evaluates ``1 - exp(-far * livetime)``.

    Parameters
    ----------
    far : float or numpy array
        False alarm rate, in units reciprocal to those of ``livetime``.
    livetime : float
        Observation time.

    Returns
    -------
    float or numpy array
        The probability corresponding to ``far``.
    """
    # expm1 keeps full precision when far * livetime is tiny, where
    # 1 - exp(-x) would round to zero.
    return -numpy.expm1(-far * livetime)

def _far_from_p(p, livetime, max_far):
    """Convert a single false alarm probability into a false alarm rate.

    Inverts ``p = 1 - exp(-far * livetime)`` and clips the result at
    ``max_far``.

    Parameters
    ----------
    p : float
        False alarm probability.
    livetime : float
        Observation time; the returned rate is in reciprocal units.
    max_far : float
        Upper bound on the returned rate. Also returned when ``p == 1``,
        for which the unclipped rate would be infinite.

    Returns
    -------
    float
        ``min(-log(1 - p) / livetime, max_far)``.
    """
    if p == 1:
        return max_far
    # log1p is accurate for arbitrarily small p, so no separate series
    # expansion is needed for the small-probability regime.
    far = -numpy.log1p(-p) / livetime
    if far > max_far:
        return max_far
    return far


# otypes is given explicitly so that empty input arrays are accepted (without
# it numpy.vectorize raises on size-0 input) and the output dtype does not
# depend on which branch the first element happens to take.
far_from_p = numpy.vectorize(_far_from_p, otypes=[float])

# Render all text with LaTeX, using the Computer Modern serif font.
pylab.rc('text', usetex=True)
pylab.rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})

parser = argparse.ArgumentParser()
# General required options
parser.add_argument('--version', action='version', version=pycbc.version.git_verbose_msg)
# The input and output files are used unconditionally below, so a missing
# value is reported by argparse instead of failing later with an obscure error.
parser.add_argument('--trigger-file', required=True,
      help='HDF5 file containing the foreground and background triggers')
parser.add_argument('--verbose', action='count',
      help='Print progress information')
parser.add_argument('--output-file', required=True,
      help='Path of the plot to write')
parser.add_argument('--trials-factor', type=int, default=1,
      help='Factor by which false alarm rates and probabilities are multiplied')
parser.add_argument('--not-cumulative', action='store_true',
      help='Plot the false alarm rate of each trigger instead of the cumulative rate')
parser.add_argument('--no-sigma-shading', action='store_true',
      help='Do not draw the shaded sigma bands')
parser.add_argument('--fg-marker', default='^', help='Marker to use for the foreground triggers')
parser.add_argument('--closed-box', action='store_true',
      help="Make a closed box version that excludes foreground triggers")
parser.add_argument('--xmin', type=float, help='Set the minimum value of the x-axis')
parser.add_argument('--xmax', type=float, help='Set the maximum value of the x-axis')
parser.add_argument('--ymin', type=float, help='Set the minimum value of the y-axis (in units of 1/years)')
parser.add_argument('--ymax', type=float, help='Set the maximum value of the y-axis (in units of 1/years)')
args = parser.parse_args()

# Cumulative mode is the default; --not-cumulative switches it off.
args.cumulative = not args.not_cumulative

# Logging is only configured (at INFO level) when --verbose was given.
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(level=log_level,
                        format='%(asctime)s : %(message)s')

logging.info('Read in the data')
f = h5py.File(args.trigger_file, 'r')

# Foreground livetime converted from seconds to Julian years.
foreground_livetime = f.attrs['foreground_time'] / lal.YRJUL_SI

# Read the foreground triggers. On success cstat_fore holds the ranking
# statistic sorted in increasing order and cstat_rate the matching y-values;
# cstat_fore is set to None if the file has no usable foreground.
try:
    cstat_fore = f['foreground/stat'][:]

    if args.cumulative:
        cstat_fore.sort()
        # Number of triggers at or above each statistic value, per year.
        # The same year definition as foreground_livetime is used so that
        # all rates in the plot are mutually consistent.
        cstat_rate = numpy.arange(len(cstat_fore), 0, -1) / foreground_livetime
        # Smallest foreground false alarm probability, with trials factor
        fap_end = f['foreground/fap'][:].min() * args.trials_factor
    else:
        ssort = cstat_fore.argsort()
        cstat_rate = 1.0 / f['foreground/ifar'][:] * args.trials_factor
        cstat_rate = cstat_rate[ssort]
        cstat_fore = cstat_fore[ssort]
except (KeyError, ValueError) as err:
    # KeyError: a foreground dataset is absent from the file.
    # ValueError: min() of an empty foreground dataset.
    logging.info('No usable foreground triggers: %s', err)
    cstat_fore = None
else:
    logging.info('Found %s foreground triggers' % len(cstat_fore))
    if len(cstat_fore) == 0:
        cstat_fore = None

# Background including zerolag triggers, ordered by increasing statistic
back_ifar = f['background/ifar'][:]
cstat_back = f['background/stat'][:]
back_sort = cstat_back.argsort()
cstat_back = cstat_back[back_sort]

# False alarm rate is the inverse of the IFAR; the trials factor is only
# applied when plotting per-trigger false alarm rates
far_back = 1.0 / back_ifar[back_sort]
if not args.cumulative:
    far_back *= args.trials_factor

fap_back = p_from_far(far_back, foreground_livetime)

# The background FAR provides the default y-limits: one decade below the
# smallest value up to the largest value
plot_ymin = far_back.min()/10. if args.ymin is None else args.ymin
plot_ymax = far_back.max() if args.ymax is None else args.ymax

logging.info('Found %s background (inclusive zerolag) triggers' % len(cstat_back))

# Background excluding zerolag triggers, ordered by increasing statistic
cstat_back_exc = f['background_exc/stat'][:]
back_ifar_exc = f['background_exc/ifar'][:]
back_sort_exc = cstat_back_exc.argsort()
cstat_back_exc = cstat_back_exc[back_sort_exc]

# As for the inclusive background, the trials factor only applies in
# non-cumulative mode
far_back_exc = 1.0 / back_ifar_exc[back_sort_exc]
if not args.cumulative:
    far_back_exc *= args.trials_factor

logging.info('Found %s background (exclusive zerolag) triggers' % len(cstat_back_exc))

# The automatic x-limits come from the exclusive background when making a
# closed box plot or when there is no foreground, and from the foreground
# (and inclusive background) otherwise. Explicit --xmin/--xmax always win.
use_background = args.closed_box or cstat_fore is None

if args.xmin is None:
    plot_xmin = cstat_back_exc.min() if use_background else cstat_fore.min()
else:
    plot_xmin = args.xmin

if args.xmax is None:
    if use_background:
        plot_xmax = cstat_back_exc.max()
    else:
        plot_xmax = max(cstat_back.max(), cstat_fore.max())
    # pad so that the furthest point to the right sits ~1/10 of the plot
    # width away from the right axis
    plot_xmax += (plot_xmax - plot_xmin)/10.
else:
    plot_xmax = args.xmax

# Create the figure and draw the zerolag-exclusive background, which is
# shown in both the open and closed box versions of the plot
fig = pylab.figure(1)
back_marker = 'x'
pylab.scatter(cstat_back_exc, far_back_exc, s=10, marker=back_marker,
              color='gray', label='Closed Box Background')

def pretty_num(num):
    """Format a positive number as ``<mantissa>e<exponent>``.

    The mantissa is printed with one decimal place and the exponent is the
    floor of the base-10 logarithm of ``num``.
    """
    exponent = numpy.floor(numpy.log10(num))
    mantissa = num / 10.0 ** exponent
    pieces = (mantissa, exponent)
    return '%2.1fe%.0f' % pieces

# Everything that reveals zerolag information is only drawn for open box plots
if not args.closed_box:
    pylab.scatter(cstat_back, far_back, color='black', marker=back_marker, s=10,
        label='Open Box Background')

    if cstat_fore is not None and len(cstat_fore):
        pylab.scatter(cstat_fore, cstat_rate, s=60, color='#ff6600',
                      marker=args.fg_marker,
                      label='Foreground', zorder=100,
                      linewidth=0.5, edgecolors='white')

        if args.cumulative and not args.no_sigma_shading:
            # Significance of the most significant foreground trigger; the
            # shaded bands start from this value
            end = sigma_from_p(fap_end)
            sigmas = numpy.array([end, 4, 3, 2, 1])
            top = cstat_rate.max()

            # offset the shading so it is valid for the foreground beyond
            # the first point; it does not depend on the band, so it is
            # computed once outside the loop
            offset = numpy.interp(cstat_back, cstat_fore,
                                  cstat_rate - cstat_rate[-1])

            # Walk over consecutive (upper, lower) sigma pairs
            for sig_hi, sig_lo in zip(sigmas[:-1], sigmas[1:]):
                if sig_lo > end:
                    continue

                p1 = p_from_sigma(sig_hi)
                p2 = p_from_sigma(sig_lo)

                low = far_back / p2 * args.trials_factor + offset
                high = far_back / p1 * args.trials_factor + offset
                pylab.fill_between(cstat_back, low, high,
                                   linewidth=0, color=pylab.cm.Blues(sig_hi / 5.0), zorder=-1)

                # Label the band where its upper edge drops below the top of
                # the foreground; skip the label if that never happens
                # (min() of an empty selection would raise)
                below_top = cstat_back[high < top]
                if below_top.size == 0:
                    continue
                left = below_top.min()

                if sig_hi == end:
                    pylab.text(left, top*1.05, r"\textbf{%1.1f $\sigma$}"  % sig_hi, fontsize=8)
                else:
                    pylab.text(left, top*1.05, r"\textbf{%1.0f $\sigma$}"  % sig_hi, fontsize=8)

        if args.not_cumulative:
            # add arrows to any points > the loudest background
            louder_pts = numpy.where(cstat_fore > cstat_back.max())[0]
            for ii in louder_pts:
                r = cstat_fore[ii]
                arr_start = cstat_rate[ii]
                # make the arrow length 1/15 the height of the plot
                arr_end = arr_start * (plot_ymin / plot_ymax) ** (1./15)
                # black shaft with a slightly wider white outline beneath it
                pylab.plot([r, r], [arr_start, arr_end], lw=2, color='black',
                    zorder=99)
                pylab.plot([r, r], [arr_start, arr_end], lw=2.6, color='white',
                    zorder=97)
                # arrow head
                pylab.scatter([r], [arr_end], marker='v', c='black',
                    edgecolors='white', lw=0.5, s=40, zorder=98)
            
if not args.cumulative:
    # add second y-axis for probabilities, sigmas
    sigmas = numpy.arange(6)+1
    ax1 = pylab.gca()
    # Axes.set_axis_bgcolor was deprecated in Matplotlib 2.0 and removed
    # afterwards; use set_facecolor when available and fall back otherwise
    if hasattr(ax1, 'set_facecolor'):
        ax1.set_facecolor('none')
    else:
        ax1.set_axis_bgcolor('none')
    ax2 = ax1.twinx()
    ax1.set_zorder(ax2.get_zorder()+1) # put axis1 on top
    pylab.sca(ax2)
    # where to stick the sigma labels; we'll put them 1/25th from the
    # right axis
    anntx = plot_xmax - (plot_xmax - plot_xmin)/25.
    sigps = p_from_sigma(sigmas)
    if not args.no_sigma_shading:
        # FAR at each integer sigma, clipped at the largest background FAR
        sigfars = far_from_p(sigps, foreground_livetime, far_back.max())
        for ii in range(len(sigfars) - 1):
            # shade the band between sigma ii+1 and the next one up
            pylab.axhspan(sigfars[ii+1], sigfars[ii],
                          linewidth=0,
                          color=pylab.cm.Blues(float(sigmas[ii+1]) / sigmas.size),
                          alpha=0.3, zorder=-1)
            # add sigma label (raw string: \s is not a valid escape)
            pylab.annotate(r'%1.0f$\sigma$' % sigmas[ii],
                           (anntx, sigfars[ii]), zorder=100)

    pylab.sca(ax1)
    
# Axis labels, log scaling and limits for the primary axes
if args.cumulative:
    pylab.ylabel('Cumulative Rate (yr$^{-1}$)')
else:
    if args.trials_factor == 1:
        pylab.ylabel('False Alarm Rate (yr$^{-1}$)')
    else:
        # any trials factor other than one means the rate has been rescaled;
        # using else guarantees the axis always gets a label
        pylab.ylabel('Combined False Alarm Rate (yr$^{-1}$)')
    # ax2 is the twin axis created for the non-cumulative plot
    ax2.set_ylabel('False Alarm Probability')

pylab.xlabel(r'Network Re-weighted SNR, $\hat{\rho}_c$')
pylab.yscale('log')
pylab.ylim(plot_ymin, plot_ymax)
pylab.xlim(plot_xmin, plot_xmax)

if args.cumulative:
    pylab.legend(loc="lower left")
    pylab.grid()
    figure_caption =  "Cumulative histogram of foreground and background triggers."
else:
    figure_caption =  "Mapping between the ranking statistic and false alarm rate."
    pylab.grid()
    pylab.legend(loc="upper right")
    # The right-hand axis shares the FAR range of the left-hand one, but is
    # labelled in terms of false alarm probability
    ax2.set_yscale('log')
    ymin, ymax = ax1.get_ylim()

    ax2.set_ylim(ymin, ymax)
    # we'll put ticks at faps of 10^{-x}, where x is an integer
    # get the ymin, ymax in terms of probability
    pymin = p_from_far(ymin, foreground_livetime)
    pymax = p_from_far(ymax, foreground_livetime)
    # figure out the range of values we need to plot: this will be the
    # floor/ceil of the min/max
    # values
    tick_min = numpy.ceil(numpy.log10(pymin))
    tick_max = numpy.floor(numpy.log10(pymax))
    # tick ranges
    pticks = numpy.arange(tick_min, tick_max+1)
    # convert back to FAR; far_from_p is only called on non-empty input
    if pticks.size:
        fticks = far_from_p(10**pticks, foreground_livetime, far_back.max())
    else:
        fticks = numpy.array([])
    ax2.set_yticks(fticks)
    # set the labels
    ax2.set_yticklabels(['$10^{%i}$' %(val) for val in pticks.astype(int)])
    # add minor ticks; the count must be an integer because it is used as an
    # array size and as a slice index
    N_minor_per_decade = 10
    n_decades = max(pticks.size - 1, 0)
    pminorticks = numpy.zeros(N_minor_per_decade * n_decades)
    for ii,p in enumerate(10**pticks[:-1]):
        idx = ii*N_minor_per_decade
        pminorticks[idx:idx+N_minor_per_decade] = \
            p * numpy.arange(1,N_minor_per_decade+1).astype(float)
    # with fewer than two major ticks there are no minor ticks to place
    if pminorticks.size:
        fminorticks = far_from_p(pminorticks, foreground_livetime, far_back.max())
        ax2.set_yticks(fminorticks, minor=True)

# Append the legend description to the caption. The parentheses group the
# adjacent string literals; there must be no trailing comma, which would turn
# the caption into a one-element tuple instead of a string.
figure_caption = figure_caption + (
    " Orange diamonds (if present) represent triggers from the "
    "zero-lag (foreground) analysis. Solid crosses show "
    "the background inclusive of zerolag events, and grey crosses show the "
    "background constructed without triggers that are "
    "coincident in the zero-lag data.")

# Use the bin name stored in the file for the title when there is one
if 'name' in f.attrs:
    figure_title = "%s bin, Cumulative Rate vs Rank" % f.attrs['name']
else:
    figure_title = "FAR vs Rank"

pycbc.results.save_fig_with_metadata(fig, args.output_file,
     title=figure_title,
     caption=figure_caption,
     cmd=' '.join(sys.argv))

# The trigger file is no longer needed once the plot has been written
f.close()
