#!/usr/bin/python
''' Make plots of recovered injection parameters
'''
import h5py, numpy, logging, argparse, sys, matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plot

import pycbc.version, pycbc.detector
from pycbc import pnutils, results

# Command-line interface of the script. All options are parsed once, at import
# time, into the module-level ``args`` namespace used by the rest of the file.
parser = argparse.ArgumentParser()
parser.add_argument("--verbose", action="count")
parser.add_argument("--version", action="version",
                    version=pycbc.version.git_verbose_msg)
parser.add_argument("--injection-file", required=True,
                    help="hdf injection file containing found injections. "
                         "Required.")
parser.add_argument("--bank-file", required=True,
                    help="hdf bank file containing template parameters. "
                         "Required")
parser.add_argument("--trigger-file",
                    help="hdf trigger merge file containing single triggers. "
                         "Required for some parameters, e.g. end_time.")
parser.add_argument("--detector", required=True,
                    help="Detector where injections are recovered. Required")
# type=float makes argparse reject a non-numeric threshold with a usage
# message instead of failing later when the value is converted
parser.add_argument("--min-ifar", type=float, default=0,
                    help="Minimum IFAR for which to plot injections. Units:years")
parser.add_argument("--error-param", required=True,
                    help="Property to be compared between injected and recovered "
                         "event, ex. 'mchirp'. Required")
parser.add_argument("--x-param",
                    help="Injected parameter to plot on x-axis. Required if "
                         "--plot-type is err_v_param or fracerr_v_param.")
# --log-x / --log-y are on/off switches. Without nargs argparse insisted on a
# value after the flag; nargs="?" with const=True accepts the bare flag while
# still accepting the older "--log-x <value>" form (the value is only ever
# tested for truthiness).
parser.add_argument("--log-x", nargs="?", const=True,
                    help="Use a logarithmic x-axis.")
parser.add_argument("--log-y", nargs="?", const=True,
                    help="Use a logarithmic y-axis.")
parser.add_argument("--plot-type", choices=["scatter", "err", "fracerr",
                    "err_v_param", "fracerr_v_param", "errhist",
                    "fracerrhist"], default="scatter",
                    help="Type of plot to show accuracy of recovery. "
                         "Default='scatter'")
parser.add_argument("--gradient-far", action="store_true",
                    help="Show FAR of found injections as a gradient")
parser.add_argument("--output-file", required=True)
args = parser.parse_args()

# --x-param only makes sense for the "*_v_param" plot types, and those plot
# types cannot be drawn without it
takes_x_param = args.plot_type in ["err_v_param", "fracerr_v_param"]
if not takes_x_param and args.x_param is not None:
    raise RuntimeError("Can't use an --x-param with this plot type!")
if takes_x_param and args.x_param is None:
    raise RuntimeError("Need an --x-param to plot errors against!")

# Logging is only configured when --verbose was given at least once
if args.verbose:
    log_level = logging.INFO
    logging.basicConfig(format="%(asctime)s : %(message)s", level=log_level)

logging.info("Reading data...")
injs = h5py.File(args.injection_file, "r")
bank = h5py.File(args.bank_file, "r")
# the trigger file is optional
trig = None
if args.trigger_file:
    trig = h5py.File(args.trigger_file, "r")

# The requested detector must be the value of one of the injection file
# attributes whose name contains "detector"
ifo = args.detector
dettuples = [(val, name) for name, val in injs.attrs.items()
             if "detector" in name]
dets = [val for val, _ in dettuples]
if ifo not in dets:
    raise RuntimeError("Can't find detector %s in injection file!" % ifo)
# lower-case first letter of the detector name, used to build
# detector-specific parameter names
site = ifo[0].lower()

# Found injections and their IFAR values
found = injs["found_after_vetoes/injection_index"][:]
ifar_found = numpy.array(injs["found_after_vetoes/ifar"][:])
# above_min_ifar is a Boolean mask over the found injections; it is applied to
# every array that has one entry per found injection
above_min_ifar = ifar_found > float(args.min_ifar)
found = found[above_min_ifar]
ifar_found = ifar_found[above_min_ifar]

def get_inj_param(injfile, param, args):
    """
    Return the value of an injected parameter for every injection in a file.

    Parameters stored directly in the "injections" group of the file are read
    as they are; a few derived quantities are calculated from stored ones.
    Only the quantity that was asked for is read or calculated.

    Parameters
    ----------
    injfile: hdf5 File object
        Injection file with an "injections" group holding one dataset per
        stored injection parameter
    param: string
        Parameter to be calculated for the injected signals: the name of a
        dataset in the "injections" group, or one of "mtotal", "mchirp",
        "eta", "effective_spin", or "end_time_" followed by the lower-case
        first letter of args.detector
    args: argparse.Namespace object
        Parsed options given to the code. Only args.detector is used, and only
        when param is not a stored or mass/spin-derived quantity

    Returns
    -------
    [return value]: NumPy array
        The parameter values, one per injection in the file

    Raises
    ------
    KeyError
        If param is neither stored in the file nor one of the derived
        quantities listed above, or if a dataset needed to derive it is
        missing
    """
    inj = injfile["injections"]
    # Stored parameters come from the file that was passed in, and are read
    # into memory so callers can index the result like any other array
    if param in inj.keys():
        return inj[param][:]

    # Derived quantities are wrapped in callables so that asking for one of
    # them does not require the datasets needed by the others
    derived = {
      "mtotal" : lambda: inj['mass1'][:] + inj['mass2'][:],
      "mchirp" : lambda: pnutils.mass1_mass2_to_mchirp_eta(
                 inj['mass1'][:], inj['mass2'][:])[0],
      "eta"    : lambda: pnutils.mass1_mass2_to_mchirp_eta(
                 inj['mass1'][:], inj['mass2'][:])[1],
      "effective_spin" : lambda: pnutils.phenomb_chi(
                 inj['mass1'][:], inj['mass2'][:],
                 inj['spin1z'][:], inj['spin2z'][:]),
    }
    if param in derived:
        return derived[param]()

    if param == "end_time_" + args.detector[0].lower():
        # The detector object is only built when the detector-frame end time
        # is actually requested
        det = pycbc.detector.Detector(args.detector)
        # longitude, latitude and end time are forwarded positionally, in
        # that order, one injection at a time
        # NOTE(review): presumably these fill the (right ascension,
        # declination, GPS time) slots of the PyCBC method - confirm
        time_delay = numpy.vectorize(
            lambda lon, lat, t: det.time_delay_from_earth_center(lon, lat, t))
        end_time = inj['end_time'][:]
        return end_time + time_delay(inj['longitude'][:], inj['latitude'][:],
                                     end_time)

    raise KeyError(param)

# Get the injected param values
# end time and effective distance differ between detectors, so the generic
# option names are translated to names carrying the site letter
_site_names = {
  "end_time" : "end_time_"+site,
  "effective_distance" : "eff_dist_"+site,
}
param = _site_names.get(args.error_param, args.error_param)
# the parameter is calculated for all injections; indexing with `found` then
# keeps only the found injections above the IFAR threshold
inj_params = get_inj_param(injs, param, args)[found]

# Get the injected x-axis values
if not args.x_param:
    # the axis label tables concatenate this into strings even for plot types
    # that never show it, so it must be a string
    args.x_param = "badger"
else:
    xparam = _site_names.get(args.x_param, args.x_param)
    inj_xparams = get_inj_param(injs, xparam, args)[found]

def get_found_param(injfile, bankfile, trigfile, param, args):
    """
    Return the value of a recovered parameter for every found injection.

    The parameter is looked up, in this order: as a dataset of the template
    bank, as a dataset of the single-detector triggers for args.detector, and
    finally as a quantity derived from the template masses and spins. Only
    the datasets needed for the requested parameter are read.

    Parameters
    ----------
    injfile: hdf5 File object
        Injection file. "found_after_vetoes/template_id" is used to index the
        bank datasets and "found_after_vetoes/trigger_id<N>" to index the
        trigger datasets, where <N> is the last character of the name of the
        file attribute whose value equals args.detector
    bankfile: hdf5 File object
        Template bank file
    trigfile: hdf5 File object or None
        Single-detector trigger file, with one group per detector. May be
        None, in which case trigger parameters are not available
    param: string
        Parameter to be calculated for the recovered triggers: a bank
        dataset name, a trigger dataset name, or one of "mtotal", "mchirp",
        "eta", "effective_spin"
    args: argparse.Namespace object
        Parsed options given to the code. Only args.detector is used

    Returns
    -------
    [return value]: NumPy array
        The parameter values, one per entry of "found_after_vetoes"

    Raises
    ------
    KeyError
        If param cannot be found or derived, or a dataset needed to derive it
        is missing
    IndexError
        If a trigger parameter is requested and no attribute of injfile has
        the value args.detector
    """
    # Index of the matching template for each found injection, read into
    # memory so it can be used to index NumPy arrays
    foundtmp = injfile["found_after_vetoes/template_id"][:]
    if param in bankfile.keys():
        return bankfile[param][:][foundtmp]

    if trigfile is not None and param in trigfile[args.detector].keys():
        # get the internal name of the ifo in the injection file, eg
        # "detector_1", and the integer from that name ...
        ifolabel = [name for name, val in injfile.attrs.items()
                    if val == args.detector][0]
        ifoid = ifolabel[-1]
        # the trigger indices are only looked up here, where they are needed,
        # so their availability is tied to trigfile itself
        foundtrg = injfile["found_after_vetoes/trigger_id"+ifoid][:]
        return trigfile[args.detector][param][:][foundtrg]

    # Derived quantities are wrapped in callables so that asking for one of
    # them does not require the datasets needed by the others
    b = bankfile
    found_param_dict = {
      "mtotal" : lambda: (b['mass1'][:] + b['mass2'][:])[foundtmp],
      "mchirp" : lambda: pnutils.mass1_mass2_to_mchirp_eta(
                 b['mass1'][:], b['mass2'][:])[0][foundtmp],
      "eta"    : lambda: pnutils.mass1_mass2_to_mchirp_eta(
                 b['mass1'][:], b['mass2'][:])[1][foundtmp],
      "effective_spin" : lambda: pnutils.phenomb_chi(
                 b['mass1'][:], b['mass2'][:],
                 b['spin1z'][:], b['spin2z'][:])[foundtmp],
    }
    return found_param_dict[param]()

# Get the recovered values
# unhack the parameter name as triggers are stored by ifo
param = args.error_param
# calculate parameters for all recovered injection triggers, then select
# only those above the IFAR threshold
rec_params = get_found_param(injs, bank, trig, param, args)[above_min_ifar]

# calculate needed values: the absolute error for every plot type whose name
# contains "err", and additionally the fractional error for the "fracerr" ones
if "err" in args.plot_type:
    diff = rec_params - inj_params
if "fracerr" in args.plot_type:
    reldiff = diff / inj_params

# Values plotted on each axis, per plot type. The entries are zero-argument
# callables because diff, reldiff and inj_xparams only exist for some plot
# types: the lookup must not touch a name until its plot type is selected.
x_dict = {
  "scatter" : lambda: inj_params,
  "err" : lambda: inj_params,
  "fracerr" : lambda: inj_params,
  "err_v_param" : lambda: inj_xparams,
  "fracerr_v_param" : lambda: inj_xparams,
  "errhist" : lambda: diff,
  "fracerrhist" : lambda: reldiff
}
xlabels = {
  "scatter" : "Injected "+args.error_param,
  "err" : "Injected "+args.error_param,
  "fracerr" : "Injected "+args.error_param,
  "err_v_param" : "Injected "+args.x_param,
  "fracerr_v_param" : "Injected "+args.x_param,
  "errhist" : "Error (rec-inj) in "+args.error_param,
  "fracerrhist" : "Fractional error (rec-inj)/inj in "+args.error_param
}

# histograms have no y values of their own, only counts
y_dict = {
  "scatter" : lambda: rec_params,
  "err" : lambda: diff,
  "fracerr" : lambda: reldiff,
  "err_v_param" : lambda: diff,
  "fracerr_v_param" : lambda: reldiff,
  "errhist" : lambda: None,
  "fracerrhist" : lambda: None
}
ylabels = {
  "scatter" : "Recovered "+args.error_param,
  "err" : "Error (rec-inj) in "+args.error_param,
  "fracerr" : "Fractional error (rec-inj)/inj in "+args.error_param,
  "err_v_param" : "Error (rec-inj) in "+args.error_param,
  "fracerr_v_param" : "Fractional error (rec-inj)/inj in "+args.error_param,
  "errhist" : "Number of injections",
  "fracerrhist" : "Number of injections"
}

xvals = x_dict[args.plot_type]()
yvals = y_dict[args.plot_type]()

logging.info("Plotting...")
fig = plot.figure()

if "hist" not in args.plot_type:
    if not args.gradient_far:
        # make a scatter with crude colour coding: three colour values
        # selected by IFAR thresholds of 100 and 1000
        color = numpy.zeros(len(found))
        hundred = ifar_found > 100
        thousand = ifar_found > 1000
        color[hundred] = 0.5
        color[thousand] = 1.0
        caption = ("Found injections: blue circles are found with IFAR<100yr, "
                   "green are IFAR<1000yr, red are IFAR >=1000yr")
        points = plot.scatter(xvals, yvals, c=color, linewidth=0, vmin=0,
                              vmax=1, s=12, marker="o", label="found",
                              alpha=0.6)

    else:
        # make a pretty rainbow coloured plot, coloured by 1/IFAR
        color = 1.0 / ifar_found
        # sort so quiet found is on top
        csort = color.argsort()
        xvals = xvals[csort]
        yvals = yvals[csort]
        color = color[csort]
        caption = ("Found injections: color indicates the estimated false "
                   "alarm rate")
        points = plot.scatter(xvals, yvals, c=color, linewidth=0,
                              norm=matplotlib.colors.LogNorm(),
                              s=16, marker="o", label="found")
        plot.subplots_adjust(right=0.99)
        try:
            c = plot.colorbar()
            c.set_label("False Alarm Rate (yr$^{-1}$)")
        except TypeError:
            # Can't make colorbar if no quiet found injections; only treat
            # the error as real when there were points to colour
            if len(found):
                raise
else:
    # it's a histogram! Use roughly sqrt(N) bins, but never fewer than one so
    # that an empty or single-entry selection can still be drawn
    nbins = max(1, int(len(xvals)**0.5))
    plot.hist(xvals, bins=nbins, histtype="step", label="found")
    caption = ("Found injections")

# Switch the requested axes to log scale, then label them for this plot type
ax = plot.gca()
for log_requested, set_scale in ((args.log_x, ax.set_xscale),
                                 (args.log_y, ax.set_yscale)):
    if log_requested:
        set_scale("log")
plot.xlabel(xlabels[args.plot_type], size='large')
plot.ylabel(ylabels[args.plot_type], size='large')
plot.grid(True)

# Extra keyword arguments handed on to the figure-saving helper
fig_kwds = {}
if ".png" in args.output_file:
    fig_kwds["dpi"] = 150

if ".html" in args.output_file:
    # interactive output: leave room on the right and attach mpld3 plugins
    plot.subplots_adjust(left=0.1, right=0.8, top=0.9, bottom=0.1)
    import mpld3, mpld3.plugins, mpld3.utils
    mpld3.plugins.connect(fig, mpld3.plugins.MousePosition(fmt=".5g"))
    # Only the scatter plot types create `points`; histograms have nothing to
    # attach an interactive legend to
    if "hist" not in args.plot_type:
        legend = mpld3.plugins.InteractiveLegendPlugin([points], ["found"],
                                                       alpha_unsel=0.1)
        mpld3.plugins.connect(fig, legend)

results.save_fig_with_metadata(fig, args.output_file,
                     fig_kwds=fig_kwds,
                     title="Injection %s recovery" % (args.error_param),
                     cmd=" ".join(sys.argv),
                     caption=caption)
logging.info("Done!")
