#!/bin/env python

import argparse
import json
from fractions import Fraction

import bilby
import h5py
import numpy as np

import pygwb.pe as pe
from pygwb.baseline import Baseline
from pygwb.detector import Interferometer

def load_npz(fname_path):
    """Load a point-estimate spectrum and its sigma spectrum from an npz file.

    Frequency bins are discarded when the point estimate is NaN or zero, or
    when sigma is zero.

    NOTE: an identical ``load_npz`` is defined again later in this file;
    being the later definition, that one is what the module ends up binding.

    Parameters
    ----------
    fname_path : str
        Path to an ``.npz`` archive holding the arrays
        ``point_estimate_spectrum`` and ``sigma_spectrum`` and, optionally,
        ``frequencies``.

    Returns
    -------
    tuple
        ``(freqs, Y_f, sigma_f)``: the filtered frequencies (``None`` when the
        archive has no ``frequencies`` entry), the real part of the filtered
        point-estimate spectrum, and the filtered sigma spectrum.
    """
    # The archive keeps a file handle open until closed; arrays are read into
    # memory on access, so they remain valid after the ``with`` block.
    with np.load(fname_path) as data:
        ptEst_ff = np.real(data["point_estimate_spectrum"])
        sigma_ff = data["sigma_spectrum"]
        freqs = data["frequencies"] if "frequencies" in data.files else None

    # Combine the three conditions with ``&``: ``np.logical_and`` only takes
    # two inputs, and a third positional argument is interpreted as ``out``,
    # which silently dropped the ``sigma_ff != 0`` condition.
    goodInds = ~np.isnan(ptEst_ff) & (ptEst_ff != 0) & (sigma_ff != 0)

    if freqs is not None:
        freqs = freqs[goodInds]
    Y_f = ptEst_ff[goodInds]
    sigma_f = sigma_ff[goodInds]
    return (freqs, Y_f, sigma_f)

def load_hdf5(fname_path):
    """Load the point-estimate and sigma spectra from an HDF5 file.

    Unlike ``load_npz``, no bins are filtered out here; the datasets are
    returned as stored.

    Parameters
    ----------
    fname_path : str
        Path to an HDF5 file holding the datasets ``point_estimate_spectrum``
        and ``sigma_spectrum`` and, optionally, ``freqs``.

    Returns
    -------
    tuple
        ``(freqs, point_estimate_spectrum, sigma_spectrum)`` as numpy arrays;
        ``freqs`` is ``None`` when the file has no ``freqs`` entry.
    """
    # Open through a context manager so the file handle is always released.
    # ``np.array`` copies each dataset into memory, so the returned arrays
    # stay usable after the file is closed.
    with h5py.File(fname_path, "r") as hf:
        point_estimate_spectrum = np.array(hf.get("point_estimate_spectrum"))
        sigma_spectrum = np.array(hf.get("sigma_spectrum"))
        if "freqs" in hf.keys():
            freqs = np.array(hf.get("freqs"))
        else:
            freqs = None

    return (freqs, point_estimate_spectrum, sigma_spectrum)

def load_npz(fname_path):
    """Load a point-estimate spectrum and its sigma spectrum from an npz file.

    Frequency bins are discarded when the point estimate is NaN or zero, or
    when sigma is zero.

    NOTE: this redefines the ``load_npz`` that appears earlier in this file;
    being the later definition, this is the one the module ends up binding.

    Parameters
    ----------
    fname_path : str
        Path to an ``.npz`` archive holding the arrays
        ``point_estimate_spectrum`` and ``sigma_spectrum`` and, optionally,
        ``frequencies``.

    Returns
    -------
    tuple
        ``(freqs, Y_f, sigma_f)``: the filtered frequencies (``None`` when the
        archive has no ``frequencies`` entry), the real part of the filtered
        point-estimate spectrum, and the filtered sigma spectrum.
    """
    # The archive keeps a file handle open until closed; arrays are read into
    # memory on access, so they remain valid after the ``with`` block.
    with np.load(fname_path) as data:
        ptEst_ff = np.real(data["point_estimate_spectrum"])
        sigma_ff = data["sigma_spectrum"]
        freqs = data["frequencies"] if "frequencies" in data.files else None

    # Combine the three conditions with ``&``: ``np.logical_and`` only takes
    # two inputs, and a third positional argument is interpreted as ``out``,
    # which silently dropped the ``sigma_ff != 0`` condition.
    goodInds = ~np.isnan(ptEst_ff) & (ptEst_ff != 0) & (sigma_ff != 0)

    if freqs is not None:
        freqs = freqs[goodInds]
    Y_f = ptEst_ff[goodInds]
    sigma_f = sigma_ff[goodInds]
    return (freqs, Y_f, sigma_f)

def _build_arg_parser():
    """Create the argument parser holding every command-line option of the script."""
    pe_parser = argparse.ArgumentParser(add_help=True)
    pe_parser.add_argument(
        "--path_to_file",
        help="Path to data file (or pickled baseline) to use for analysis.",
        action="store",
        type=str,
        required=True,
    )
    pe_parser.add_argument(
        "--ifos",
        help="List of names of two interferometers for which you want to run PE, default is H1 and L1. Not needed when running from a pickled baseline.",
        action="store",
        type=str,
        required=False,
        default="None",
    )
    pe_parser.add_argument(
        "--model",
        help="Model for which to run the PE. \n Default is \"Power-law\". ",
        action="store",
        type=str,
        required=False,
        default="Power-Law",
    )
    pe_parser.add_argument(
        "--model_prior_file",
        help="The required parameters for the model and their priors. \n Default value is power-law parameters omega_ref with LogUniform bilby prior from 1e-11 to 1e-8 and alpha with Uniform bilby prior from -4 to 4.",
        type=str,
        action="store",
        required=False,
        default="None",
    )
    pe_parser.add_argument(
        "--non_prior_args",
        help="A dictionary with the parameters of the model that are not associated with a prior, such as the reference frequency for Power-Law. Default value is reference frequency at 25 Hz for the power-law model.",
        type=str,
        action="store",
        required=False,
        default="None",
    )
    pe_parser.add_argument(
        "--output_dir",
        help="Output directory of the PE (sampler). Default: ./PE_Output.",
        type=str,
        action="store",
        required=False,
        default="./PE_Output",
    )
    pe_parser.add_argument(
        "--injection_parameters",
        help="The injected parameters.",
        required=False,
        default="None",
        type=str,
        action="store",
    )
    pe_parser.add_argument(
        "--quantiles",
        help="The quantiles used for plotting in plot corner, default is [0.05, 0.95].",
        type=str,
        action="store",
        required=False,
        default="None",
    )
    pe_parser.add_argument(
        "--f0",
        help="If no frequencies are saved in the loaded npz file, you can give f0, fhigh and df to the script. This is f0, the begin frequency.",
        type=float,
        action="store",
        default=None,
    )
    pe_parser.add_argument(
        "--fhigh",
        help="If no frequencies are saved in the loaded npz file, you can give f0, fhigh and df to the script. This is fhigh, the end frequency.",
        type=float,
        action="store",
        default=None,
    )
    pe_parser.add_argument(
        "--df",
        help="If no frequencies are saved in the loaded npz file, you can give f0, fhigh and df to the script. This is df, the frequency spacing.",
        type=float,
        action="store",
        default=None,
    )
    pe_parser.add_argument(
        "--calibration_epsilon",
        help="Calibration uncertainty. 0 by default.",
        type=float,
        action="store",
        default=None,
        required=False,
    )
    return pe_parser


def _corner_plot_ranges(priors):
    """Return one ``(minimum, maximum)`` plot range per prior, in dict order.

    Infinite prior bounds are replaced by -4 (lower) and 4 (upper) so that
    every range handed to the corner plot is finite.
    """
    ranges = []
    for prior in priors.values():
        lower = -4 if prior.minimum == -np.inf else prior.minimum
        upper = 4 if prior.maximum == np.inf else prior.maximum
        ranges.append((lower, upper))
    return ranges


def main():
    """Run parameter estimation on a spectrum file or a pickled baseline.

    Steps, in order:

    1. Parse the command-line options (unknown options are ignored).
    2. Select the model class and check that the priors plus the non-prior
       arguments cover exactly the parameters listed for that model.
    3. Build a ``Baseline`` from an npz/HDF5 spectrum file, or load one from
       a pickle.
    4. Run the ``dynesty`` sampler through ``bilby.run_sampler`` and produce
       a corner plot of the result.

    Raises
    ------
    ValueError
        If the model is not supported, if a required model parameter is
        missing, or if the data file has no frequencies and ``--f0``,
        ``--fhigh`` and ``--df`` were not all given.
    AttributeError
        If a supplied parameter does not belong to the chosen model.
    TypeError
        If the file extension of ``--path_to_file`` is not supported.
    """
    script_args = _build_arg_parser().parse_known_args()[0]

    dictionary_models = {
        "Power-Law": pe.PowerLawModel,
        "Broken-Power-Law": pe.BrokenPowerLawModel,
        "Triple-Broken-Power-Law": pe.TripleBrokenPowerLawModel,
        "Smooth-Broken-Power-Law": pe.SmoothBrokenPowerLawModel,
        "Schumann": pe.SchumannModel,
        "Parity-Violation": pe.PVPowerLawModel,
        "Parity-Violation-2": pe.PVPowerLawModel2,
    }

    dictionary_parameters = {
        "Power-Law": ["omega_ref", "alpha", "fref"],
        "Broken-Power-Law": ["omega_ref", "fbreak", "alpha_1", "alpha_2"],
        "Triple-Broken-Power-Law": ["omega_ref", "alpha_1", "alpha_2", "alpha_3", "fbreak1", "fbreak2"],
        "Smooth-Broken-Power-Law": ["omega_ref", "fbreak", "alpha_1", "alpha_2", "delta"],
        "Schumann": ["kappa_H1", "kappa_L1", "beta_L1", "beta_H1"],
        "Parity-Violation": ["omega_ref", "alpha", "Pi"],
        "Parity-Violation-2": ["omega_ref", "alpha", "beta"],
    }

    if script_args.model not in dictionary_models:
        # f-string so the offending model name actually appears in the message.
        raise ValueError(
            f"The model {script_args.model} you provided is not supported."
        )
    class_model = dictionary_models[script_args.model]
    required_parameters = dictionary_parameters[script_args.model]

    # "None" (the string) is the sentinel default of the string-typed options.
    if script_args.model_prior_file == "None":
        dict_loaded = {
            "omega_ref": bilby.core.prior.LogUniform(1e-11, 1e-8, "$\\Omega_{ref}$"),
            "alpha": bilby.core.prior.Uniform(-4, 4, "$\\alpha$"),
        }
    else:
        # Context manager so the prior file is closed once it has been parsed.
        with open(script_args.model_prior_file) as prior_file:
            dict_loaded = json.load(
                prior_file, object_hook=bilby.core.utils.decode_bilby_json
            )

    if script_args.non_prior_args == "None":
        extra_kwargs = {"fref": 25}
    else:
        extra_kwargs = json.loads(script_args.non_prior_args)

    all_parameters_dict = dict_loaded.copy()
    all_parameters_dict.update(extra_kwargs)

    # Every supplied parameter must belong to the model ...
    for key in all_parameters_dict:
        if key not in required_parameters:
            raise AttributeError(
                f"The parameter {key} you provided is not associated with the given model. \n These are the required parameters: {required_parameters}."
            )

    # ... and every parameter of the model must have been supplied.
    if not all(key in all_parameters_dict for key in required_parameters):
        raise ValueError(
            f"Not all required parameters of the model are present in one of the parameters dictionaries. \n These are the required parameters: {required_parameters}."
        )

    path_to_file = script_args.path_to_file
    # ".hdf5" is accepted next to ".h5" since the error message below
    # advertises hdf5 files as supported.
    if path_to_file.endswith(("npz", ".h5", ".hdf5")):
        if path_to_file.endswith("npz"):
            frequencies, point_estimate_spectrum, sigma_spectrum = load_npz(path_to_file)
        else:
            frequencies, point_estimate_spectrum, sigma_spectrum = load_hdf5(path_to_file)

        if frequencies is None:
            if None in (script_args.f0, script_args.fhigh, script_args.df):
                raise ValueError(
                    "The data file contains no frequencies; please provide --f0, --fhigh and --df."
                )
            # NOTE(review): load_npz drops bad bins from the spectra, so a grid
            # rebuilt here may not have the same length as the spectra - confirm.
            frequencies = np.arange(
                script_args.f0, script_args.fhigh + script_args.df, script_args.df
            )

        if script_args.ifos == "None":
            ifos = ["H1", "L1"]
        else:
            ifos = json.loads(script_args.ifos)

        ifo_1 = Interferometer.get_empty_interferometer(ifos[0])
        ifo_2 = Interferometer.get_empty_interferometer(ifos[1])

        base_12 = Baseline.from_interferometers([ifo_1, ifo_2])

        if script_args.calibration_epsilon is not None:
            base_12.calibration_epsilon = script_args.calibration_epsilon

        # The polarization is hard-coded; there is no command-line option for it.
        if not base_12._orf_polarization_set:
            base_12.orf_polarization = "tensor"

        base_12.frequencies = frequencies

        base_12.point_estimate_spectrum = point_estimate_spectrum
        base_12.sigma_spectrum = sigma_spectrum

    elif path_to_file.endswith((".p", ".pickle")):
        base_12 = Baseline.load_from_pickle(path_to_file)

    else:
        raise TypeError(
            "The provided file format is currently not supported, try an npz, hdf5 file or a pickled baseline instead."
        )

    kwargs = {"baselines": [base_12], "model_name": script_args.model}
    kwargs.update(extra_kwargs)
    model = class_model(**kwargs)
    priors = dict_loaded

    if script_args.injection_parameters == "None":
        injection_parameters = dict(omega_ref=10**(-8.67985), alpha=2/3)
    else:
        injection_parameters = json.loads(script_args.injection_parameters)
        # Going through Fraction lets values be given as strings such as "2/3".
        for par in injection_parameters:
            injection_parameters[par] = float(Fraction(injection_parameters[par]))

    range_list = _corner_plot_ranges(priors)

    hlv = bilby.run_sampler(
        likelihood=model,
        priors=priors,
        sampler="dynesty",
        npoints=1000,
        walks=10,
        maxmcmc=10000,
        outdir=script_args.output_dir,
        label="hlv",
        injection_parameters=injection_parameters,
    )

    if script_args.quantiles == "None":
        quantiles = [0.05, 0.95]
    else:
        quantiles = json.loads(script_args.quantiles)

    hlv.plot_corner(
        show_titles=True,
        title_fmt=".3g",
        use_math_text=True,
        quantiles=quantiles,
        range=range_list,
    )
    

# Run the parameter estimation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
