#!/usr/bin/env python
import numpy, h5py, argparse, matplotlib, sys
matplotlib.use('Agg')
import pylab, pycbc.results, pycbc.version
from pycbc.events import veto
from pycbc.io import get_chisq_from_file_choice, chisq_choices, HFile
from pycbc.io import SingleDetTriggers

# Command-line interface. The trigger file and the output path are needed by
# every run, so they are declared required: a missing option is then reported
# by argparse itself rather than surfacing later as an opaque I/O error.
parser = argparse.ArgumentParser()
parser.add_argument('--trigger-file', required=True,
                    help='Single ifo trigger file')
parser.add_argument('--version', action='version',
                    version=pycbc.version.git_verbose_msg)
parser.add_argument('--veto-file',
                    help='Optional, file of veto segments to remove triggers')
parser.add_argument('--segment-name', default=None, type=str,
                    help='Optional, name of segment list to use for vetoes')
parser.add_argument('--min-snr', type=float,
                    help='Optional, Minimum SNR to plot')
parser.add_argument('--output-file', required=True,
                    help='Path of the plot file to write')
# Contour values are intentionally kept as strings: they are printed verbatim
# in the contour labels and converted with float() where they are evaluated.
parser.add_argument('--newsnr-contours', nargs='*', default=[],
                    help="List of newsnr values to draw contours at.")
parser.add_argument('--chisq-choice', choices=chisq_choices,
                    default='traditional',
                    help='Which chisquared to plot. Default=traditional')
args = parser.parse_args()

# The first top-level group name in the trigger file is used as the detector
# (ifo) label. The context manager guarantees the handle is released even if
# reading the keys fails.
with h5py.File(args.trigger_file, 'r') as f:
    group_names = tuple(f.keys())
if not group_names:
    raise ValueError('No detector groups found in trigger file %s'
                     % args.trigger_file)
ifo = group_names[0]

# Optional pre-selection on SNR. The truthiness test means that both an
# absent --min-snr and a value of 0 skip the pre-selection entirely.
if args.min_snr:
    with HFile(args.trigger_file, 'r') as trig_file:
        n_triggers_orig = trig_file[f'{ifo}/snr'].size
        # Indices of the triggers at or above the threshold; the second
        # return value (the selected data) is not needed here.
        idx, _ = trig_file.select(
            lambda snr: snr >= args.min_snr,
            f'{ifo}/snr',
            return_indices=True
        )
        # Expand the index list into a boolean mask over all triggers.
        data_mask = numpy.zeros(n_triggers_orig, dtype=bool)
        data_mask[idx] = True
else:
    data_mask = None

# Load the triggers, passing the veto file / segment name from the command
# line and the SNR mask built above.
trigs = SingleDetTriggers(
    args.trigger_file,
    None,
    args.veto_file,
    args.segment_name,
    None,
    ifo,
    premask=data_mask
)

# Quantities shown on the x (SNR) and y (chi-squared) axes of the plot.
snr = trigs['snr']
chisq = get_chisq_from_file_choice(trigs, args.chisq_choice)

def snr_from_chisq(chisq, newsnr, q=6.):
    """Return the SNR on a constant-newsnr contour at each chisq value.

    Parameters
    ----------
    chisq : numpy.ndarray
        One-dimensional array of chi-squared values at which to evaluate
        the contour.
    newsnr : float or str
        Contour value; anything accepted by ``float()``.
    q : float, optional
        Exponent used in the chi-squared weighting. Default 6.

    Returns
    -------
    numpy.ndarray
        ``newsnr`` wherever ``chisq <= 1``; elsewhere
        ``newsnr / (0.5 * (1 + chisq ** (q / 2))) ** (-1 / q)``.
    """
    target = float(newsnr)
    # Start from the un-weighted value, which holds wherever chisq <= 1.
    contour = numpy.full(len(chisq), target)
    above_unity = chisq > 1.
    # Weighting factor applied only to the entries with chisq above one.
    weight = (0.5 * (1. + chisq[above_unity] ** (q/2.))) ** (-1./q)
    contour[above_unity] = target / weight
    return contour

fig = pylab.figure(1)

# Chi-squared values, evenly spaced in log10, at which each contour is
# evaluated. numpy.logspace takes base-10 exponents, so the limits must be the
# log10 (not the natural log) of the data range for the samples to span
# exactly [chisq.min(), chisq.max()].
r = numpy.logspace(numpy.log10(chisq.min()), numpy.log10(chisq.max()), 300)
# A contour is labelled at the first point where it exceeds 80% of the
# largest SNR in the data.
label_snr = snr.max() * 0.8
for i, cval in enumerate(args.newsnr_contours):
    snrv = snr_from_chisq(r, cval)
    pylab.plot(snrv, r, color='black', lw=0.5)
    # Only the first label spells out the rho-hat symbol; the rest show the
    # bare contour value.
    if i == 0:
        label = "$\\hat{\\rho} = %s$" % cval
    else:
        label = "$%s$" % cval
    try:
        label_pos_idx = numpy.where(snrv > label_snr)[0][0]
    except IndexError:
        # The contour never exceeds the threshold: label its first point.
        label_pos_idx = 0
    pylab.text(snrv[label_pos_idx], r[label_pos_idx], label, fontsize=6,
               horizontalalignment='center', verticalalignment='center',
               bbox=dict(facecolor='white', lw=0, pad=0, alpha=0.9))

# Two-dimensional trigger density on log-log axes with a logarithmic colour
# scale; bins with no triggers are left blank (mincnt=1).
pylab.hexbin(snr, chisq, gridsize=300, xscale='log', yscale='log', lw=0.04,
             mincnt=1, norm=matplotlib.colors.LogNorm())

ax = pylab.gca()
pylab.grid()
ax.set_xscale('log')
cb = pylab.colorbar()
# Fix the axis limits to the data range, with 10% headroom at the upper ends.
pylab.xlim(snr.min(), snr.max() * 1.1)
pylab.ylim(chisq.min(), chisq.max() * 1.1)
cb.set_label('Trigger Density')
pylab.xlabel('Signal-to-Noise Ratio')
pylab.ylabel('Reduced $\\chi^2$')
# Write the figure together with its title, caption and the command line
# that produced it.
pycbc.results.save_fig_with_metadata(fig, args.output_file,
     title="%s :SNR vs Reduced %s &chi;<sup>2</sup>" % (ifo, args.chisq_choice),
     caption="Distribution of SNR and %s &chi;&sup2; for single detector triggers: "
             "Black lines show contours of constant NewSNR." \
              %(args.chisq_choice,),
     cmd=' '.join(sys.argv),
     fig_kwds={'dpi':300})
