#!/usr/bin/python

# Copyright 2016 Thomas Dent
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.

from __future__ import division

import argparse
import copy
import logging
import sys

import h5py
import numpy as np

import pycbc.version
from pycbc import io, events, pnutils
from pycbc.events import trigger_fits as trstats

#### DEFINITIONS AND FUNCTIONS ####

# Lookup table from the name given to --sngl-stat to the callable that
# evaluates that statistic. Every entry is called with (snr, rchisq); get_stat
# additionally passes a third "factor" argument to the new_snr and
# effective_snr entries when --stat-factor is supplied.
stat_dict = dict(
    new_snr=events.newsnr,
    effective_snr=events.effsnr,
    # plain SNR: the reduced chisq is accepted but ignored
    snr=lambda snr, rchisq: snr,
    # SNR divided by the square root of the reduced chisq
    snronchi=lambda snr, rchisq: snr / (rchisq ** 0.5),
)

def get_stat(statchoice, snr, rchisq, fac):
    """
    Evaluate the chosen single-detector statistic.

    Parameters
    ----------
    statchoice : str
        Key into stat_dict selecting the statistic.
    snr, rchisq :
        SNR and reduced chisq values handed to the statistic function.
    fac : float or None
        Optional extra factor. When not None it is passed as a third argument,
        which is only allowed for "new_snr" and "effective_snr".

    Returns
    -------
    Whatever the selected stat_dict entry returns.

    Raises
    ------
    RuntimeError
        If fac is given together with a statistic other than "new_snr" or
        "effective_snr".
    """
    # no factor: every statistic takes just (snr, rchisq)
    if fac is None:
        return stat_dict[statchoice](snr, rchisq)
    # a factor only makes sense for the statistics that accept one
    if statchoice not in ("new_snr", "effective_snr"):
        raise RuntimeError("Can't use --stat-factor with this statistic!")
    return stat_dict[statchoice](snr, rchisq, fac)

#### MAIN ####

# Command-line interface. Options whose help text says "Required" are enforced
# by argparse so that a missing one is reported immediately instead of causing
# an obscure failure later in the script.
parser = argparse.ArgumentParser(usage="",
    description="Perform maximum-likelihood fits of single inspiral trigger"
                " distributions to various functions")

parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("-V", "--verbose", action="store_true",
                    help="Print extra debugging information", default=False)
parser.add_argument("--trigger-file", required=True,
                    help="Input hdf5 file containing single triggers. "
                    "Required")
parser.add_argument("--template-file", required=True,
                    help="hdf file containing template parameters. Required")
parser.add_argument("--template-fraction-range", default="0/1",
                    help="Optional, analyze only part of template bank. "
                    "Format is PART/NUM_PARTS")
parser.add_argument("--veto-file", nargs='*',
                    help="File(s) in .xml format with veto segments to apply "
                    "to triggers before fitting")
parser.add_argument("--veto-segment-name", default=None,
                    help="Name(s) of veto segments to apply. Optional, if not "
                    "given all segments for a given ifo will be used")
parser.add_argument("--output", required=True,
                    help="Location for output file containing fit coefficients"
                    ". Required")
parser.add_argument("--ifo", required=True,
                    help="Ifo producing triggers to be fitted. Required")
parser.add_argument("--fit-function",
                    choices=["exponential", "rayleigh", "power"],
                    help="Functional form for the maximum likelihood fit")
parser.add_argument("--sngl-stat", default="new_snr",
                    choices=["snr", "snronchi", "effective_snr", "new_snr"],
                    help="Function of SNR and chisq to perform fits with")
parser.add_argument("--stat-factor", type=float,
                    help="Adjustable magic number used in some sngl "
                    "statistics. Values commonly used: 6 for new_snr, 250 "
                    "or 50 for effective_snr")
parser.add_argument("--stat-threshold", type=float, required=True,
                    help="Only fit triggers with statistic value above this "
                    "threshold.  Required.  Typically 6-6.5")
# The choices for --prune-param are deliberately not hard-coded (yet): the
# value is interpreted by get_param, which falls back to pnutils lookups for
# anything other than mchirp, mtotal and template_duration.
parser.add_argument("--prune-param",
                    help="Parameter to define bins for 'pruning' loud triggers"
                    " to make the fit insensitive to signals and outliers. "
                    "Choose from mchirp, mtotal, template_duration or a named "
                    "frequency cutoff in pnutils or a frequency function in "
                    "LALSimulation")
parser.add_argument("--prune-bins", type=int,
                    help="Number of bins to divide bank into when pruning")
parser.add_argument("--prune-number", type=int,
                    help="Number of loudest events to prune in each bin")
parser.add_argument("--log-param", action='store_true',
                    help="Bin in the log of prune-param")
# type=float is needed so that a value given on the command line is numeric
# like the default, rather than a string
parser.add_argument("--f-lower", type=float, default=-1.,
                    help="Starting frequency for calculating template "
                    "duration, required if this is the fit parameter")
# FIXME : Would like to have choice of SEOBNRv2 or PhenD duration formula
parser.add_argument("--min-duration", type=float, default=0.,
                    help="Fudge factor for templates with tiny or negative "
                    "values of template_duration: add to duration values "
                    "before fitting. Units seconds")

args = parser.parse_args()

# The three pruning options only make sense together: reject the case where
# some, but not all, of them were given (a value of 0 counts as not given).
prune_opts = (args.prune_param, args.prune_bins, args.prune_number)
if any(prune_opts) and not all(prune_opts):
    raise RuntimeError("To prune, need to specify param, number of bins and "
                       "nonzero number to prune in each bin!")

# --verbose switches the log level from warnings-only to full debug output
log_level = logging.DEBUG if args.verbose else logging.WARN
logging.basicConfig(format='%(asctime)s : %(message)s', level=log_level)

# containers filled by the per-template fitting loop further down
tids, counts, counts_above, fits = [], [], [], []

logging.info('Fitting above threshold %f' % args.stat_threshold)

# both input files are only ever read from
logging.info('Opening trigger file: %s' % args.trigger_file)
trigf = h5py.File(args.trigger_file, mode='r')
logging.info('Opening template file: %s' % args.template_file)
templatef = h5py.File(args.template_file, mode='r')

# reduced chisq = chisq / (2 * chisq_dof - 2); computed in one expression so
# the two input arrays are not kept around afterwards
rchisq = (trigf[args.ifo+'/chisq'][:] /
          (2 * trigf[args.ifo+'/chisq_dof'][:] - 2))
snr = trigf[args.ifo+'/snr'][:]
logging.info('Calculating stat values')
stat = get_stat(args.sngl_stat, snr, rchisq, args.stat_factor)
# the inputs are no longer needed once the statistic has been computed
del snr, rchisq

# first cut: keep only triggers at or above the statistic threshold, which
# reduces the number of triggers handled by everything that follows
abovethresh = stat >= args.stat_threshold
stat = stat[abovethresh]
tid, time = [trigf[args.ifo + key][:][abovethresh]
             for key in ('/template_id', '/end_time')]
logging.info('%i trigs left after thresholding' % len(stat))

# now do vetoing; --veto-file is optional, so only call the veto code when at
# least one file was given (otherwise args.veto_file is None or empty and
# there is nothing to remove)
if args.veto_file:
    # indices of the triggers whose times lie outside the veto segments; the
    # second return value is not used here
    retain, junk = events.veto.indices_outside_segments(
        time, args.veto_file, ifo=args.ifo,
        segment_name=args.veto_segment_name)
    stat = stat[retain]
    tid = tid[retain]
    time = time[retain]
    logging.info('%i trigs left after vetoing with %s' %
                 (len(stat), args.veto_file))
else:
    logging.info('No veto file given, not applying vetoes')

### Functions for doing the pruning (removal of trigs at loudest times)

def get_masses(args, tid):
    """
    Look up the masses and z-spins of the template(s) with index tid.

    Reads the 'mass1', 'mass2', 'spin1z' and 'spin2z' datasets of the
    module-level template file `templatef` and indexes each with tid. The
    args parameter is accepted but not used.

    Returns
    -------
    tuple
        (mass1, mass2, spin1z, spin2z) for the requested template(s).
    """
    return tuple(templatef[name][:][tid]
                 for name in ('mass1', 'mass2', 'spin1z', 'spin2z'))

def get_param(args, m1, m2, s1z, s2z):
    """
    Compute the pruning parameter selected by args.prune_param.

    Parameters
    ----------
    args :
        Parsed command-line options; prune_param selects the parameter, and
        f_lower and min_duration are read for 'template_duration'.
    m1, m2, s1z, s2z :
        Component masses and z-spins of the template(s).

    Returns
    -------
    The parameter value(s): chirp mass for 'mchirp', m1 + m2 for 'mtotal',
    the pnutils SEOBNR-ROM duration plus args.min_duration for
    'template_duration', otherwise the result of a pnutils frequency lookup.
    """
    if args.prune_param == 'mchirp':
        # second return value (eta) is not needed
        parvals, _ = pnutils.mass1_mass2_to_mchirp_eta(m1, m2)
    elif args.prune_param == 'mtotal':
        parvals = m1 + m2
    elif args.prune_param == 'template_duration':
        parvals = pnutils.get_seobnrrom_duration(m1, m2, s1z, s2z,
                                                 f_low=args.f_lower)
        # fudge factor for tiny or negative durations (see --min-duration)
        parvals += args.min_duration
    elif args.prune_param in pnutils.named_frequency_cutoffs:
        # look the cutoff up by the same option that was just matched; there
        # is no 'fit_param' option in this script
        parvals = pnutils.frequency_cutoff_from_name(args.prune_param, m1, m2,
                                                     s1z, s2z)
    else:
        # try asking for a LALSimulation frequency function
        parvals = pnutils.get_freq(args.prune_param, m1, m2, s1z, s2z)
    return parvals

# size of the template bank, needed later to split it into parts
num_templates = len(templatef['template_hash'])
if args.prune_param:
    # evaluate the pruning parameter over the whole bank to find the range
    # that the pruning bins have to cover
    logging.info('Getting min and max param values')
    bank_pars = get_param(args, *[templatef[key][:] for key in
                                  ('mass1', 'mass2', 'spin1z', 'spin2z')])
    minpar, maxpar = min(bank_pars), max(bank_pars)
    # only the extremes are kept
    del bank_pars
    logging.info('%f %f' % (minpar, maxpar))

def which_bin(par, minpar, maxpar, nbins, log=False):
    """
    Returns the bin index in which the parameter value belongs (from 0 through
    nbins-1) when dividing the range between minpar and maxpar equally into
    nbins bins.

    Parameters
    ----------
    par : float
        Parameter value to be binned; must satisfy minpar <= par <= maxpar.
    minpar, maxpar : float
        Lower and upper ends of the binned range.
    nbins : int
        Number of equal-width bins.
    log : bool, optional
        If True, the bins have equal width in the natural log of the
        parameter rather than in the parameter itself.

    Returns
    -------
    int
        Bin index between 0 and nbins - 1. A value equal to maxpar is put in
        the last bin, which is also the result when minpar == maxpar.

    Raises
    ------
    ValueError
        If par lies outside [minpar, maxpar].
    """
    # explicit check rather than assert, so it is not stripped under -O
    if not minpar <= par <= maxpar:
        raise ValueError('Parameter value %s is outside the range [%s, %s]'
                         % (par, minpar, maxpar))
    if log:
        par, minpar, maxpar = np.log(par), np.log(minpar), np.log(maxpar)
    # corner case: the upper edge belongs to the last bin. Checking this before
    # dividing also covers a zero-width range, where par must equal maxpar
    if par == maxpar:
        return nbins - 1
    # par lies some fraction of the way between min and max
    frac = float(par - minpar) / float(maxpar - minpar)
    # the index then lies between 0 and nbins - 1; the clamp is only a
    # safeguard against rounding
    return min(int(frac * nbins), nbins - 1)

# Pruning: repeatedly take the loudest remaining trigger, work out which
# parameter bin its template falls in, and - unless that bin already has its
# quota of pruned events - remove every trigger within a short time window of
# it from the arrays used for fitting.
if args.prune_param:
    # hard-coded time window of 0.1s
    args.prune_window = 0.1
    # total number of events removed once every bin has its quota
    prune_target = args.prune_bins * args.prune_number
    # hard-coded cap on the number of loudest events that are examined
    max_trials = 1000
    # initialize bin storage: times of the events pruned in each bin
    prunedtimes = {}
    for i in range(args.prune_bins):
        prunedtimes[i] = []

    # keep a record of the triggers if all successive loudest events were to
    # be pruned
    statpruneall = copy.deepcopy(stat)
    tidpruneall = copy.deepcopy(tid)
    timepruneall = copy.deepcopy(time)

    # many trials may be required to prune in 'quieter' bins
    for j in range(max_trials):
        # are all the bins full already?
        numpruned = sum([len(prunedtimes[i]) for i in range(args.prune_bins)])
        if numpruned == prune_target:
            logging.info('Finished pruning!')
            break
        if numpruned > prune_target:
            logging.error('Uh-oh, we pruned too many things .. %i, to be '
                          'precise' % numpruned)
            raise RuntimeError('Pruned %i events but only %i were requested'
                               % (numpruned, prune_target))
        # nothing left to look at: np.argmax would fail on an empty array
        if len(statpruneall) == 0:
            logging.warning('No triggers left to examine after pruning %i of '
                            '%i requested events' % (numpruned, prune_target))
            break
        loudest = np.argmax(statpruneall)
        lstat = statpruneall[loudest]
        ltid = tidpruneall[loudest]
        ltime = timepruneall[loudest]
        m1, m2, s1z, s2z = get_masses(args, ltid)
        lbin = which_bin(get_param(args, m1, m2, s1z, s2z), minpar, maxpar,
                         args.prune_bins, log=args.log_param)
        # is the bin where the loudest trigger lives full already?
        if len(prunedtimes[lbin]) == args.prune_number:
            logging.info('%i - Bin %i full, not pruning event with stat %f at time '
                         '%.3f' % (j, lbin, lstat, ltime))
        else:
            logging.info('Pruning event with stat %f at time %.3f in bin %i' %
                         (lstat, ltime, lbin))
            # now do the pruning
            retain = abs(time - ltime) > args.prune_window
            logging.info('%i trigs before pruning' % len(stat))
            stat = stat[retain]
            logging.info('%i trigs remain' % len(stat))
            tid = tid[retain]
            time = time[retain]
            # record the time
            prunedtimes[lbin].append(ltime)
        # in either case remove the event from the reference trigger arrays,
        # so that the next trial moves on to the next loudest time
        retain = abs(timepruneall - ltime) > args.prune_window
        statpruneall = statpruneall[retain]
        tidpruneall = tidpruneall[retain]
        timepruneall = timepruneall[retain]
        del retain
    else:
        # the loop used up all its trials; the last trial may still have
        # completed the pruning, so recount before complaining
        numpruned = sum([len(prunedtimes[i]) for i in range(args.prune_bins)])
        if numpruned < prune_target:
            logging.warning('Stopped pruning after %i trials with only %i of '
                            '%i requested events pruned'
                            % (max_trials, numpruned, prune_target))
    del statpruneall
    del tidpruneall
    del timepruneall
    logging.info('%i trigs remain after pruning loop' % len(stat))

# parse template range, given as PART/NUM_PARTS: this job fits the templates
# in slice number PART (counting from 0) of NUM_PARTS equal slices of the bank
rangestr = args.template_fraction_range
try:
    part, pieces = [int(field) for field in rangestr.split('/')]
except ValueError:
    # wrong number of fields or non-integer content
    raise RuntimeError("--template-fraction-range must have the format "
                       "PART/NUM_PARTS, got '%s'" % rangestr)
# a slice outside the bank would produce fits for templates that don't exist
if pieces < 1 or not 0 <= part < pieces:
    raise RuntimeError("--template-fraction-range needs NUM_PARTS >= 1 and "
                       "0 <= PART < NUM_PARTS, got '%s'" % rangestr)
tmin = int(num_templates / float(pieces) * part)
tmax = int(num_templates / float(pieces) * (part + 1))
trange = range(tmin, tmax)

# fit the statistic distribution of each template in the selected range
for tnum in trange:
    # statistic values of the surviving triggers from this template
    stat_in_template = stat[tid == tnum]
    count_above = len(stat_in_template)
    if count_above:
        # the second return value is not stored
        alpha, sig_alpha = trstats.fit_above_thresh(
            args.fit_function, stat_in_template, args.stat_threshold)
    else:
        # 'stupid' value to indicate no data, shouldn't hurt if 1/alpha is
        # averaged
        alpha = -100.
    tids.append(tnum)
    counts_above.append(count_above)
    fits.append(alpha)
    # progress report, every time the template id is a multiple of 100
    if tnum % 100 == 0:
        logging.info('Fitted template %i / %i' % (tnum - tmin, tmax - tmin))

# Write the results; the context manager closes the output file even if one
# of the writes fails.
with h5py.File(args.output, 'w') as outfile:
    # store template-dependent fit output
    outfile.create_dataset("template_id", data=trange)
    outfile.create_dataset("count_above_thresh", data=counts_above)
    outfile.create_dataset("fit_coeff", data=fits)
    # add some metadata
    outfile.attrs.create("ifo", data=args.ifo)
    outfile.attrs.create("fit_function", data=args.fit_function)
    outfile.attrs.create("sngl_stat", data=args.sngl_stat)
    # compare with None so that an explicitly given factor of 0 is recorded
    if args.stat_factor is not None:
        outfile.attrs.create("stat_factor", data=args.stat_factor)
    outfile.attrs.create("stat_threshold", data=args.stat_threshold)

# the input files are not needed any more either
trigf.close()
templatef.close()
logging.info('Done!')
