#!/bin/env python

import argparse
import distutils
import json
import os
import sys
from pathlib import Path
from typing import List

import bilby
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger

from pygwb.baseline import Baseline
from pygwb.detector import Interferometer
from pygwb.parameters import Parameters, ParametersHelp


def help_arguments(parent):
    """Build a parser exposing one ``--<name>`` option per ``Parameters`` field.

    The returned parser inherits every option of ``parent`` and adds a
    string-typed, optional argument for each entry found in
    ``Parameters.__annotations__``, with its help text looked up in
    ``ParametersHelp``. Fields annotated as ``List`` accept one or more values.

    Parameters
    ----------
    parent : argparse.ArgumentParser
        Parser whose arguments are inherited by the returned parser.

    Returns
    -------
    argparse.ArgumentParser
        Parser carrying both the parent options and the parameter options.
    """
    help_parser = argparse.ArgumentParser(parents=[parent])
    annotations = getattr(Parameters, "__annotations__", {})
    for option, annotation in annotations.items():
        option_help = ParametersHelp[option].help
        # Only list-annotated fields take several values on the command line.
        extra_kwargs = {"nargs": "+"} if annotation == List else {}
        help_parser.add_argument(
            f"--{option}",
            help=option_help,
            type=str,
            required=False,
            **extra_kwargs,
        )
    return help_parser

# Strings accepted for the boolean script options (same sets as the former
# ``distutils.util.strtobool``, which is gone from the standard library in
# Python 3.12).
_TRUE_STRINGS = frozenset({"y", "yes", "t", "true", "on", "1"})
_FALSE_STRINGS = frozenset({"n", "no", "f", "false", "off", "0"})

# (flag, help text) for every script-level option; all are optional strings.
_SCRIPT_OPTIONS = (
    ("--param_file", "Parameter file to use for analysis."),
    ("--output_path", "Location to save output to."),
    ("--calc_coh", "Calculate coherence spectrum from data."),
    ("--calc_pt_est", "Calculate omega point estimate and sigma from data."),
    ("--apply_dsc", "Apply delta sigma cut when calculating final output."),
    ("--pickle_out", "Pickle output Baseline of the analysis."),
    ("--wipe_ifo", "Wipe interferometer data to reduce size of pickled Baseline."),
    ("--file_tag", "File naming tag. By default, reads in first and last time in dataset."),
)


def _str_to_bool(value):
    """Convert a truth-value string (case-insensitive) to ``bool``.

    Raises
    ------
    ValueError
        If ``value`` is not one of the recognised true/false strings.
    """
    lowered = value.lower()
    if lowered in _TRUE_STRINGS:
        return True
    if lowered in _FALSE_STRINGS:
        return False
    raise ValueError(f"invalid truth value {lowered!r}")


def _flag_or_default(raw, default):
    """Return the parsed boolean of ``raw``, or ``default`` when ``raw`` is unset/empty."""
    return _str_to_bool(raw) if raw else default


def _build_pipe_parser():
    """Create the help-less parser holding the script-level options."""
    pipe_parser = argparse.ArgumentParser(add_help=False)
    for flag, help_text in _SCRIPT_OPTIONS:
        pipe_parser.add_argument(
            flag, help=help_text, action="store", type=str, required=False
        )
    return pipe_parser


def _load_gate_file(gate_path, t0, tf):
    """Load the ``.npz`` gate file at ``gate_path``.

    If ``gate_path`` is not a file it is searched for ``point_estimate*.npz``
    files whose name contains ``"{t0}-{tf}"``; the first match in sorted order
    is loaded.

    Raises
    ------
    FileNotFoundError
        If ``gate_path`` is not a file and no file matches.
    """
    if gate_path.is_file():
        return np.load(gate_path)
    candidates = sorted(gate_path.glob("point_estimate*.npz"))
    matching = [path for path in candidates if path.match(f"*{t0}-{tf}*")]
    if not matching:
        raise FileNotFoundError(
            f"No gate file matching 'point_estimate*.npz' for times {t0}-{tf} "
            f"found in {gate_path}."
        )
    # Paths yielded by glob() already include gate_path; joining them onto
    # gate_path again would duplicate a relative directory prefix.
    return np.load(matching[0])


def main():
    """Run the two-interferometer analysis driven by command-line options.

    Steps, in order: read parameters from ``--param_file`` and from the
    remaining command-line arguments; save the final parameter file; build the
    two interferometers; optionally gate the data (from file or by
    autogating); build the baseline and its PSDs/CSDs; compute the delta sigma
    cut; optionally compute the coherence spectrum; then either compute and
    save the point estimate, sigma and PSDs/CSDs (optionally pickling the
    baseline) or save only the PSDs/CSDs.

    Boolean options (``--calc_coh``, ``--calc_pt_est``, ``--apply_dsc``,
    ``--pickle_out``, ``--wipe_ifo``) are given as truth-value strings such as
    ``True``/``False``; when omitted they default to False, True, True, True
    and True respectively.

    Raises
    ------
    ValueError
        If a boolean option has an unrecognised value, or if the frequency
        spacing of the computed spectra differs from the requested resolution.
    FileNotFoundError
        If gating from a directory is requested and no gate file matches.
    """
    params = Parameters()
    pipe_parser = _build_pipe_parser()

    help_args = help_arguments(pipe_parser)
    help_args.parse_known_args()  # for help

    script_args, parameters_args = pipe_parser.parse_known_args()

    if script_args.param_file:
        params.update_from_file(script_args.param_file)
    params.update_from_arguments(parameters_args)
    logger.info("Successfully imported parameters from paramfile and input.")

    if script_args.output_path:
        output_path = Path(script_args.output_path)
        # exist_ok avoids failing if the directory appears between a check and the mkdir.
        output_path.mkdir(parents=True, exist_ok=True)
    else:
        output_path = Path('')

    pipe_calculate_coherence = _flag_or_default(script_args.calc_coh, False)
    pipe_calculate_point_estimate = _flag_or_default(script_args.calc_pt_est, True)
    pipe_apply_dsc = _flag_or_default(script_args.apply_dsc, True)
    pipe_pickle_baseline = _flag_or_default(script_args.pickle_out, True)
    wipe_ifo = _flag_or_default(script_args.wipe_ifo, True)

    if not script_args.file_tag:
        script_args.file_tag = f"{int(params.t0)}-{int(params.tf)}"

    if script_args.param_file:
        outfile_path = f"{output_path}/{Path(script_args.param_file).stem}_{script_args.file_tag}_final{Path(script_args.param_file).suffix}"
    else:
        outfile_path = Path(f"{output_path}/parameters_{script_args.file_tag}_final.ini")
    params.save_paramfile(str(outfile_path))
    logger.info(f"Saved final param file at {outfile_path}.")

    param_dict = params.parse_ifo_parameters()
    ifo_list = params.interferometer_list
    ifo_1 = Interferometer.from_parameters(ifo_list[0], param_dict[ifo_list[0]])
    ifo_2 = Interferometer.from_parameters(ifo_list[1], param_dict[ifo_list[1]])
    logger.info("Loaded up interferometers with selected parameters.")

    if params.gate_data:
        if params.path_gate_data:
            logger.info("Loading gates from file.")
            params.path_gate_data = Path(params.path_gate_data)
            npzobject = _load_gate_file(
                params.path_gate_data, int(params.t0), int(params.tf)
            )
            for index, ifo_obj in enumerate([ifo_1, ifo_2]):
                ifo_obj.apply_gates_from_file(
                    npzobject,
                    index + 1,
                    gate_tpad=param_dict[ifo_list[index]].gate_tpad,
                )
                logger.info(f"Gates loaded and applied for IFO {ifo_list[index]}:{ifo_obj.gates}")

        else:
            logger.info("Applying autogating procedure.")
            for ifo, ifo_obj in zip(ifo_list, [ifo_1, ifo_2]):
                gate_params = {
                    "gate_tzero": param_dict[ifo].gate_tzero,
                    "gate_tpad": param_dict[ifo].gate_tpad,
                    "gate_threshold": param_dict[ifo].gate_threshold,
                    "cluster_window": param_dict[ifo].cluster_window,
                    "gate_whiten": param_dict[ifo].gate_whiten,
                }
                ifo_obj.gate_data_apply(**gate_params)
                logger.info(f"Gates for IFO {ifo}:{ifo_obj.gates}")

    base_12 = Baseline.from_parameters(ifo_1, ifo_2, params)
    logger.info(
        f"Baseline with interferometers {ifo_1.name}, {ifo_2.name} initialised."
    )

    logger.info("Setting PSDs and CSDs of the baseline...")
    base_12.set_cross_and_power_spectral_density(params.frequency_resolution)
    base_12.set_average_power_spectral_densities()
    base_12.set_average_cross_spectral_density()

    logger.info("... done.")

    # Check nothing's gone wrong in the frequency handling: the spacing of the
    # baseline frequencies must match the requested resolution.
    deltaF = base_12.frequencies[1] - base_12.frequencies[0]
    if abs(deltaF - params.frequency_resolution) > 1e-6:
        raise ValueError("Frequency resolution in PSD/CSD is different than requested.")

    base_12.calculate_delta_sigma_cut(
        delta_sigma_cut=params.delta_sigma_cut,
        alphas=params.alphas_delta_sigma_cut,
        fref=params.fref,
        flow=params.flow,
        fhigh=params.fhigh,
        return_naive_and_averaged_sigmas=True,
    )

    logger.info(
        f"times flagged by the delta sigma cut as badGPStimes:{base_12.badGPStimes}"
    )

    if pipe_calculate_coherence:
        logger.info("calculating the coherence...")

        base_12.set_coherence_spectrum(apply_dsc=pipe_apply_dsc)

    if pipe_calculate_point_estimate:
        logger.info("calculating the point estimate and sigma...")

        base_12.set_point_estimate_sigma(
            alpha=params.alpha,
            fref=params.fref,
            flow=params.flow,
            fhigh=params.fhigh,
            apply_dsc=pipe_apply_dsc,
        )

        logger.success(
            f"Ran stochastic search over times {script_args.file_tag}"
        )
        logger.success(f"\tPOINT ESTIMATE: {base_12.point_estimate:e}")
        logger.success(f"\tSIGMA: {base_12.sigma:e}")

        data_file_name = f"{output_path}/point_estimate_sigma_{script_args.file_tag}"

        logger.info(
            "Saving point_estimate and sigma spectrograms, spectra, and final values to file."
        )
        logger.info("Saving average psds and csd to file.")
        base_12.save_point_estimate_spectra(
            params.save_data_type,
            data_file_name,
        )
        csd_file_name = f"{output_path}/psds_csds_{script_args.file_tag}"
        base_12.save_psds_csds(
            params.save_data_type,
            csd_file_name,
        )
        if pipe_pickle_baseline:
            logger.info("Pickling the baseline.")
            pickle_name = f"{output_path}/{base_12.name}_{script_args.file_tag}.pickle"
            base_12.save_to_pickle(pickle_name, wipe=wipe_ifo)

    else:
        logger.info("Saving average psds and csd to file.")

        data_file_name = f"{output_path}/psds_csds_{script_args.file_tag}"

        # NOTE(review): this branch does not pass params.save_data_type, unlike
        # the save_psds_csds call in the point-estimate branch, and the
        # baseline is never pickled here -- confirm both are intended.
        base_12.npz_save_psds_csds(
            data_file_name,
            base_12.frequencies,
            base_12.csd,
            base_12.interferometer_1.average_psd,
            base_12.interferometer_2.average_psd,
        )

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
