#!/bin/env python
"""Rewrite statmap file and rerank candidates using the statistic values
generated from the followup of candidates.
"""
import numpy, argparse, pycbc
from pycbc.io import HFile
from pycbc.conversions import sec_to_year
from pycbc.events import significance
from shutil import copyfile

# Command line interface. Every file option except --ranking-file is used
# unconditionally further down, so they are marked as required to get a clear
# usage error rather than a traceback from HFile / copyfile when one is
# missing.
parser = argparse.ArgumentParser(description=__doc__)
pycbc.add_common_pycbc_options(parser)
parser.add_argument('--stat-files', nargs='+', required=True,
    help="Statistic files produced by candidate followup codes")
parser.add_argument('--followup-file', required=True,
    help="File containing the candidate times which were analyzed")
parser.add_argument('--statmap-file', required=True,
    help="The statmap file containing the candidates to rerank")
significance.insert_significance_option_group(parser)
parser.add_argument('--ranking-file',
    help="Provided only for injection sets, use this file to provide the "
         "background to rank candidate significance")
parser.add_argument('--output-file', required=True,
    help="Output file: a copy of the statmap file with the statistic and "
         "significance values replaced")

args = parser.parse_args()
pycbc.init_logging(args.verbose)

# Validate the options added by insert_significance_option_group
significance.check_significance_options(args, parser)

# Reconstruct the full set of statistic values for our candidates
with HFile(args.followup_file, 'r') as f:
    num = len(f['offsets'])  # Number of followups done

    # Mapping between the followups done and the original candidate list.
    # The followup list may be shorter due to duplicates in the original
    # set (which combines background / background_exc, etc.)
    inv = f['inverse'][:]

    # Names of the statmap groups whose 'stat' datasets get rewritten below
    sections = f.attrs['sections']

# Each stat file holds the values for followups start_index,
# start_index + stride, start_index + 2 * stride, ...; interleave them back
# into a single array indexed by followup number.
stats = numpy.zeros(num)
for fname in args.stat_files:
    with HFile(fname, 'r') as f:
        s = f.attrs['start_index']
        stride = f.attrs['stride']
        stats[s::stride] = f['stat'][:]

# Expand from unique followups back to the original candidate ordering
stats = stats[inv]

# copy statmap file to output since we'll
# only make a few modifications
copyfile(args.statmap_file, args.output_file)

# Open the copy explicitly in read/write mode. When no mode is given,
# h5py >= 3.0 opens files read-only (HFile is presumably a thin wrapper around
# h5py.File -- confirm), which would make the in-place dataset updates below
# fail. 'r+' requires the file to exist, which copyfile has just ensured.
o = HFile(args.output_file, 'r+')

# Detector combination label with spaces removed; used as the key into the
# per-combination significance options
ifo_combo = o.attrs['ifos'].replace(' ','')

significance_dict = significance.digest_significance_options([ifo_combo],
                                                             args)

# Update the statistic values. The reranked values are laid out section after
# section, in the order given by `sections`, so consume them from the front.
for sec in sections:
    # New stats for this section
    nsize = len(o[sec]['stat'])
    o[sec]['stat'][...] = stats[:nsize]
    # use for next section
    stats = stats[nsize:]

# Livetimes stored in the statmap file. They are combined with sec_to_year
# below, so they are taken to be in seconds. The "_exc" variants pair with the
# background_exc / *_exc datasets.
background_time = o.attrs['background_time']
coinc_time = o.attrs['foreground_time']
coinc_time_exc = o.attrs['foreground_time_exc']
background_time_exc = o.attrs['background_time_exc']

# The reranked foreground statistics are needed in both cases below
fstat = o['foreground/stat'][:]

# Injection run
if args.ranking_file:
    # The background used to rank the candidates comes from the separate
    # ranking file rather than from the file being rewritten
    with HFile(args.ranking_file, 'r') as f:
        backstat = f['background_exc/stat'][:]
        dec = f['background_exc/decimation_factor'][:]

    # Only the per-foreground counts are needed here; the per-background
    # counts (first return value) are discarded.
    _, fnum = significance.get_n_louder(
        backstat,
        fstat,
        dec,
        **significance_dict[ifo_combo])

    ifar = background_time / (fnum + 1)
    fap = 1 - numpy.exp(- coinc_time / ifar)
    o['foreground/ifar'][...] = sec_to_year(ifar)
    o['foreground/fap'][...] = fap

    # Only one background is used for injections, so the "_exc" results
    # are copies of the values just written
    o['foreground/ifar_exc'][...] = o['foreground/ifar'][:]
    o['foreground/fap_exc'][...] = o['foreground/fap'][:]

# full data run
else:
    backstat = o['background/stat'][:]
    dec = o['background/decimation_factor'][:]
    dec_exc = o['background_exc/decimation_factor'][:]
    backstat_exc = o['background_exc/stat'][:]

    # Rank against both background sets using the updated statistic values
    bnum, fnum = significance.get_n_louder(
        backstat,
        fstat,
        dec,
        **significance_dict[ifo_combo])

    bnum_exc, fnum_exc = significance.get_n_louder(
        backstat_exc,
        fstat,
        dec_exc,
        **significance_dict[ifo_combo])

    # The background events themselves also get updated IFARs
    o['background/ifar'][...] = sec_to_year(background_time / (bnum + 1))
    o['background_exc/ifar'][...] = sec_to_year(background_time_exc / (bnum_exc + 1))

    ifar = background_time / (fnum + 1)
    fap = 1 - numpy.exp(- coinc_time / ifar)
    o['foreground/ifar'][...] = sec_to_year(ifar)
    o['foreground/fap'][...] = fap

    ifar_exc = background_time_exc / (fnum_exc + 1)
    fap_exc = 1 - numpy.exp(- coinc_time_exc / ifar_exc)
    o['foreground/ifar_exc'][...] = sec_to_year(ifar_exc)
    o['foreground/fap_exc'][...] = fap_exc

# Close explicitly so all updates are flushed to disk rather than relying on
# interpreter shutdown to do it
o.close()
