#!/bin/env python

import argparse
import copy
import os
from fractions import Fraction
from os import listdir
from os.path import isdir, isfile, join
from pathlib import Path

import numpy as np
from loguru import logger
from tqdm import tqdm

from pygwb.baseline import Baseline
from pygwb.coherence import calculate_coherence
from pygwb.constants import h0 as pygwb_h0
from pygwb.omega_spectra import OmegaSpectrum, reweight_spectral_object
from pygwb.parameters import Parameters
from pygwb.postprocessing import (
    calc_Y_sigma_from_Yf_sigmaf,
    combine_spectra_with_sigma_weights,
)


def sort_ptest_files(item):
    """Return the numeric sort key encoded in a point-estimate file name.

    The key is taken from the text that follows the last
    ``point_estimate_sigma_`` marker: its final underscore-separated chunk
    is kept, and the part of that chunk before the first ``-`` is
    converted to ``float``.

    Parameters
    ----------
    item : str
        File name or path of a point-estimate file.

    Returns
    -------
    float
        Leading number of the trailing ``<number>-...`` chunk.

    Raises
    ------
    ValueError
        If the extracted text cannot be converted to ``float``.
    """
    _, _, tail = item.rpartition("point_estimate_sigma_")
    # A file tag may itself contain underscores; only the last chunk matters.
    span = tail.rpartition("_")[2]
    start, _, _ = span.partition("-")
    return float(start)


def sort_coh_files(item):
    """Return the numeric sort key encoded in a coherence (PSD/CSD) file name.

    The key is taken from the text that follows the last ``psds_csds_``
    marker: its final underscore-separated chunk is kept, and the part of
    that chunk before the first ``-`` is converted to ``float``.

    Parameters
    ----------
    item : str
        File name or path of a ``psds_csds_`` file.

    Returns
    -------
    float
        Leading number of the trailing ``<number>-...`` chunk.

    Raises
    ------
    ValueError
        If the extracted text cannot be converted to ``float``.
    """
    _, _, tail = item.rpartition("psds_csds_")
    # A file tag may itself contain underscores; only the last chunk matters.
    span = tail.rpartition("_")[2]
    start, _, _ = span.partition("-")
    return float(start)


"""
SCRIPT TO COMBINE PYGWB_PIPE RUN OUTPUTS.
Currently only works with npz files; compatibility with other formats will be added over time.
"""


def _parse_bool_flag(value):
    """Interpret a command-line string (``True``, ``false``, ``1``, ...) as a bool.

    ``argparse`` with ``type=bool`` turns every non-empty string, including
    ``"False"``, into ``True``; this parser looks at the text instead.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def _concatenate_nonempty_gates(gate_sets):
    """Concatenate the gate arrays that hold more than one element.

    The input list is returned unchanged when no entry has more than one
    element.
    """
    populated = [gates for gates in gate_sets if gates.size > 1]
    if populated:
        return np.concatenate(populated)
    return gate_sets


def _write_spectrum(spectrum, path):
    """Write ``spectrum`` to ``path``, logging a warning if an ``OSError`` occurs."""
    try:
        spectrum.write(path)
    except OSError as err:
        # Best effort: keep going, but make the failure visible.
        logger.warning(f"Unable to write {path}: {err}")


def main():
    """Combine the outputs of several pygwb_pipe runs into run-level products.

    Command-line arguments select the point-estimate files (or a folder
    holding them), the spectral index ``--alpha`` and reference frequency
    ``--fref`` used for re-weighting, the parameter file, an optional ``h0``,
    and the output folder / file tag.

    Files written to ``--out_path``:

    * ``point_estimate_sigma_<tag>_UNWEIGHTED.npz`` -- per-file and
      per-segment point estimates and sigmas,
    * ``delta_sigma_cut_<tag>.npz`` -- sigma, bad-GPS-time and gating data,
    * ``Y_spectrum_*`` / ``sigma_spectrum_*`` ``.hdf5`` files (unweighted and
      re-weighted),
    * ``point_estimate_sigma_spectra_alpha_<alpha>_fref_<fref>_<tag>.npz``,
    * ``coherence_spectrum_<tag>.npz`` when coherence combination is enabled.

    Raises
    ------
    ValueError
        If the data path is neither a file nor a folder, if no point-estimate
        files are found, if coherence is requested for individual files
        without ``--coherence_path``, or if the file tag has to be inferred
        from file names that cannot be parsed.
    SystemExit
        Always raised at the end of a successful run (exit status 0).
    """
    combine_parser = argparse.ArgumentParser()
    combine_parser.add_argument(
        "--data_path",
        help="Path to data files or folder.",
        action="store",
        type=Path,
        nargs='+',
        required=True,
    )
    combine_parser.add_argument(
        "--alpha",
        help="Spectral index alpha to use for spectral re-weighting.",
        action="store",
        type=str,
        required=True,
    )
    combine_parser.add_argument(
        "--fref",
        help="Reference frequency to use when presenting results.",
        action="store",
        type=int,
    )
    combine_parser.add_argument(
        "--param_file", help="Parameter file", action="store", type=str
    )
    combine_parser.add_argument(
        "--h0",
        help="Value of h0 to use. Default is pygwb.constants.h0.",
        action="store",
        type=float,
        required=False,
    )
    combine_parser.add_argument(
        "--combine_coherence",
        help="Calculate combined coherence over all available data.",
        action="store",
        # type=bool would map the string "False" to True.
        type=_parse_bool_flag,
        nargs="?",
        const=True,
        required=False,
        default=False,
    )
    combine_parser.add_argument(
        "--coherence_path",
        help="Path to coherence data files, if individual files are passed.",
        action="store",
        type=Path,
        nargs='+',
        required=False,
        default=None,
    )
    combine_parser.add_argument(
        "--out_path", help="Output path.", action="store", type=Path, required=False
    )
    combine_parser.add_argument(
        "--file_tag",
        help="File naming tag. By default, reads in first and last time in dataset.",
        action="store",
        type=str,
        required=False,
    )
    combine_args = combine_parser.parse_args()
    if not combine_args.h0:
        combine_args.h0 = pygwb_h0
    if not combine_args.out_path:
        combine_args.out_path = Path("./")
    # makedirs also creates missing parent folders of a nested output path.
    os.makedirs(combine_args.out_path, exist_ok=True)
    if combine_args.coherence_path:
        combine_args.combine_coherence = True
    # alpha may be given as a fraction string such as "2/3".
    combine_args.alpha = float(Fraction(combine_args.alpha))
    params = Parameters()
    params.update_from_file(combine_args.param_file)

    # ------------------------------------------------------------------
    # Collect input files
    # ------------------------------------------------------------------
    data_root = combine_args.data_path[0]
    files_coh = None
    if isfile(data_root):
        files_ptest = [str(pt_file) for pt_file in combine_args.data_path]
        if combine_args.combine_coherence:
            if not combine_args.coherence_path:
                raise ValueError(
                    "Coherence combination was requested for individual data "
                    "files, but no --coherence_path files were provided."
                )
            files_coh = [str(coh_file) for coh_file in combine_args.coherence_path]
    elif isdir(data_root):
        dir_entries = listdir(data_root)
        files_ptest = [
            join(data_root, f)
            for f in dir_entries
            if isfile(join(data_root, f)) and f.startswith("point")
        ]
        if combine_args.combine_coherence:
            files_coh = [
                join(data_root, f)
                for f in dir_entries
                if isfile(join(data_root, f)) and f.startswith("psds")
            ]
            if not files_coh:
                logger.warning(
                    f"No coherence files found in {data_root}; "
                    "skipping the coherence combination."
                )
                combine_args.combine_coherence = False
                files_coh = None
    else:
        raise ValueError("Data path provided is neither a file nor a folder.")

    if not files_ptest:
        raise ValueError(f"No point estimate files found in {data_root}.")

    try:
        files_ptest.sort(key=sort_ptest_files)
    except Exception:
        logger.warning('Unable to sort point estimate files! '
                       'This may have unintended effects.')
    if combine_args.combine_coherence:
        try:
            files_coh.sort(key=sort_coh_files)
        except Exception:
            logger.warning('Unable to sort coherence files! '
                           'This may have unintended effects.')
    else:
        files_coh = None

    # The start times are only needed to build a default file tag, so file
    # names that cannot be parsed are harmless when --file_tag is supplied.
    if not combine_args.file_tag:
        try:
            times = [int(sort_ptest_files(pt_file)) for pt_file in files_ptest]
        except ValueError as err:
            raise ValueError(
                "Unable to read times from the point estimate file names; "
                "please provide --file_tag explicitly."
            ) from err
        combine_args.file_tag = f"{times[0]}-{times[-1]}"

    with np.load(files_ptest[0]) as first_file:
        frequencies = first_file["frequencies"]
        frequency_mask = first_file["frequency_mask"]

    # spectral objects
    Y_j = []
    sigma_j = []
    Y_spectra_j = []
    sigma_spectra_j = []
    Y_seg = []
    sigma_seg = []
    # DQ objects
    naive_sigmas_j = []
    slide_sigmas_j = []
    delta_sigmas_j = []
    badGPStimes_j = []
    times_j = []
    gates_ifo1_j = []
    gates_ifo2_j = []
    gates_ifo1_pad_j = []
    gates_ifo2_pad_j = []

    pt_est_sigma_unweighted_path = os.path.join(
        combine_args.out_path,
        f"point_estimate_sigma_{combine_args.file_tag}_UNWEIGHTED.npz",
    )
    delta_sigma_cut_output_path = os.path.join(
        combine_args.out_path,
        f"delta_sigma_cut_{combine_args.file_tag}.npz",
    )
    pt_est_sigma_spectra_path = os.path.join(
        combine_args.out_path,
        f"point_estimate_sigma_spectra_alpha_{combine_args.alpha:.1f}"
        f"_fref_{combine_args.fref}_{combine_args.file_tag}.npz",
    )

    # ------------------------------------------------------------------
    # Unpack the per-job files
    # ------------------------------------------------------------------
    logger.info('Unpacking files...')
    for ptest_file in tqdm(files_ptest):
        # NOTE: allow_pickle=True can execute arbitrary code while loading;
        # only run this on trusted pipeline output.
        with np.load(ptest_file, allow_pickle=True) as data_file:
            Y_j.append(data_file["point_estimate"])
            sigma_j.append(data_file["sigma"])
            Y_spectra_j.append(data_file["point_estimate_spectrum"])
            sigma_spectra_j.append(data_file["sigma_spectrum"])
            naive_sigmas_j.append(data_file["naive_sigma_values"].T)
            slide_sigmas_j.append(data_file["slide_sigma_values"].T)
            delta_sigmas_j.append(data_file["delta_sigma_values"].T)
            times_j.append(data_file["delta_sigma_times"])
            gates_ifo1_j.append(data_file["ifo_1_gates"])
            gates_ifo2_j.append(data_file["ifo_2_gates"])
            gates_ifo1_pad_j.append(data_file["ifo_1_gate_pad"])
            gates_ifo2_pad_j.append(data_file["ifo_2_gate_pad"])

            # Each key access re-reads the archive member, so read it once.
            bad_gps = data_file["badGPStimes"]
            if bad_gps.size == 1:
                # A single entry is flattened to 1-D so it can be concatenated.
                badGPStimes_j.append(bad_gps.flatten())
            elif bad_gps.size != 0:
                badGPStimes_j.append(bad_gps)

            Y_s, sigma_s = calc_Y_sigma_from_Yf_sigmaf(
                data_file["point_estimate_spectrogram"],
                data_file["sigma_spectrogram"],
                frequency_mask=frequency_mask,
            )
        # Scalars are wrapped so that np.concatenate below accepts them.
        if not np.shape(Y_s):
            Y_s = np.array([Y_s])
        if not np.shape(sigma_s):
            sigma_s = np.array([sigma_s])
        Y_seg.append(Y_s)
        sigma_seg.append(sigma_s)

    Y_seg = np.concatenate(Y_seg)
    sigma_seg = np.concatenate(sigma_seg)
    Y_j = np.array(Y_j)
    sigma_j = np.array(sigma_j)
    np.savez(
        pt_est_sigma_unweighted_path,
        point_estimate=Y_j,
        sigma=sigma_j,
        point_estimate_per_seg=Y_seg,
        sigma_per_seg=sigma_seg,
    )
    logger.info(
        f"saved file with unweighted point estimate and sigma values for all times in run:\n {pt_est_sigma_unweighted_path}."
    )

    naive_sigmas_j = np.concatenate(naive_sigmas_j)
    slide_sigmas_j = np.concatenate(slide_sigmas_j)
    delta_sigmas_j = np.concatenate(delta_sigmas_j)
    times_j = np.concatenate(times_j)
    # concatenate gate info - filtering for empty sets of gates
    gates_ifo1_j = _concatenate_nonempty_gates(gates_ifo1_j)
    gates_ifo2_j = _concatenate_nonempty_gates(gates_ifo2_j)
    gates_ifo1_pad_j = np.array(gates_ifo1_pad_j)
    gates_ifo2_pad_j = np.array(gates_ifo2_pad_j)
    try:
        badGPStimes_j = np.concatenate(badGPStimes_j)
    except ValueError:
        # when there are no badGPStimes in the whole set: keep the list as is
        pass
    np.savez(
        delta_sigma_cut_output_path,
        naive_sigmas=naive_sigmas_j,
        slide_sigmas=slide_sigmas_j,
        delta_sigmas=delta_sigmas_j,
        badGPStimes=badGPStimes_j,
        times=times_j,
        gates_ifo1=gates_ifo1_j,
        gates_ifo2=gates_ifo2_j,
        gates_ifo1_pad=gates_ifo1_pad_j,
        gates_ifo2_pad=gates_ifo2_pad_j,
    )
    logger.info(
        f"saved file with all sigma information related to the delta sigma cut for all times in run:\n {delta_sigma_cut_output_path}."
    )

    # ------------------------------------------------------------------
    # Combine spectra
    # ------------------------------------------------------------------
    # make combination robust against weird zero-sigma segments...
    for spec in sigma_spectra_j:
        spec[spec == 0.0] = np.inf
    Y_spectrum_combined, sigma_spectrum_combined = combine_spectra_with_sigma_weights(
        np.array(Y_spectra_j), np.array(sigma_spectra_j)
    )

    Y_spectrum = OmegaSpectrum(
        Y_spectrum_combined,
        alpha=params.alpha,
        fref=params.fref,
        h0=pygwb_h0,
        name="Y_spectrum",
        frequencies=frequencies,
    )
    sigma_spectrum = OmegaSpectrum(
        sigma_spectrum_combined,
        alpha=params.alpha,
        fref=params.fref,
        h0=pygwb_h0,
        name="sigma_spectrum",
        frequencies=frequencies,
    )

    _write_spectrum(
        Y_spectrum,
        os.path.join(
            combine_args.out_path,
            f"Y_spectrum_{combine_args.file_tag}_UNWEIGHTED.hdf5",
        ),
    )
    _write_spectrum(
        sigma_spectrum,
        os.path.join(
            combine_args.out_path,
            f"sigma_spectrum_{combine_args.file_tag}_UNWEIGHTED.hdf5",
        ),
    )

    logger.info(
        f"Saved file with combined point estimate and sigma OmegaSpectrum objects for this run. These are weighted with alpha={params.alpha}"
    )

    Y_estimate, sigma_estimate = calc_Y_sigma_from_Yf_sigmaf(
        Y_spectrum,
        sigma_spectrum,
        frequency_mask=frequency_mask,
        alpha=combine_args.alpha,
        fref=combine_args.fref,
    )
    # Rescale from the h0 of the spectra to the h0 requested on the command line.
    Y_estimate *= (Y_spectrum.h0 / combine_args.h0) ** 2
    sigma_estimate *= (sigma_spectrum.h0 / combine_args.h0) ** 2

    logger.info(
        "Final point estimate re-weighted with alpha={:.2f}".format(combine_args.alpha)
        + f" at reference frequency fref={combine_args.fref} with h0={combine_args.h0}:\n [{Y_estimate} +/- {sigma_estimate}]"
    )

    Y_reweight_spectrum = OmegaSpectrum(
        Y_spectrum_combined,
        alpha=params.alpha,
        fref=params.fref,
        h0=pygwb_h0,
        name="Y_spectrum",
        frequencies=frequencies,
    )
    Y_reweight_spectrum.reweight(
        new_alpha=combine_args.alpha, new_fref=combine_args.fref
    )
    sigma_reweight_spectrum = OmegaSpectrum(
        sigma_spectrum_combined,
        alpha=params.alpha,
        fref=params.fref,
        h0=pygwb_h0,
        name="sigma_spectrum",
        frequencies=frequencies,
    )
    sigma_reweight_spectrum.reweight(
        new_alpha=combine_args.alpha, new_fref=combine_args.fref
    )
    Y_reweight_spectrum.reset_h0(new_h0=combine_args.h0)
    sigma_reweight_spectrum.reset_h0(new_h0=combine_args.h0)

    reweight_suffix = (
        f"alpha_{combine_args.alpha:.1f}"
        f"_fref_{combine_args.fref}_{combine_args.file_tag}.hdf5"
    )
    _write_spectrum(
        Y_reweight_spectrum,
        os.path.join(combine_args.out_path, f"Y_spectrum_{reweight_suffix}"),
    )
    _write_spectrum(
        sigma_reweight_spectrum,
        os.path.join(combine_args.out_path, f"sigma_spectrum_{reweight_suffix}"),
    )

    np.savez(
        pt_est_sigma_spectra_path,
        point_estimate=Y_estimate,
        sigma_estimate=sigma_estimate,
        point_estimate_spectrum=Y_reweight_spectrum.value,
        sigma_spectrum=sigma_reweight_spectrum.value,
        frequencies=frequencies,
        frequency_mask=frequency_mask,
        point_estimates_seg_UW=Y_seg,
        sigmas_seg_UW=sigma_seg,
    )
    logger.info(
        f"Saved file with re-weighted point estimate and sigma values and spectra:\n {pt_est_sigma_spectra_path}."
    )

    # ------------------------------------------------------------------
    # Combine coherences (optional)
    # ------------------------------------------------------------------
    if combine_args.combine_coherence:
        coherence_path = os.path.join(
            combine_args.out_path,
            f"coherence_spectrum_{combine_args.file_tag}.npz",
        )
        psd_1_average = None
        psd_2_average = None
        csd_average = None
        with np.load(files_coh[0], allow_pickle=True) as first_coh:
            coh_frequencies = first_coh['avg_freqs']
        n_segs_coh_total = 0

        if len(files_coh) != len(files_ptest):
            logger.warning(
                f"Found {len(files_coh)} coherence files but {len(files_ptest)} "
                "point estimate files; combining all coherence files."
            )

        logger.info("Combining coherences over all files...")
        # Every coherence file is used; nothing is read from the point
        # estimate files in this loop.
        for file_coh in tqdm(files_coh):
            with np.load(file_coh, allow_pickle=True) as data_file:
                psd_1_coh = data_file['psd_1_coh']
                psd_2_coh = data_file['psd_2_coh']
                csd_coh = data_file['csd_coh']

                all_nan = (
                    np.isnan(psd_1_coh).all()
                    or np.isnan(psd_2_coh).all()
                    or np.isnan(csd_coh).all()
                )
                if all_nan:
                    logger.info("Removed a coherence set from the data as results were all NaN.")
                    continue
                n_segs_coh = data_file['n_segs_coh']

            # Accumulate sums weighted by the number of segments in each file.
            n_segs_coh_total += n_segs_coh
            if psd_1_average is None:
                psd_1_average = n_segs_coh * psd_1_coh
                psd_2_average = n_segs_coh * psd_2_coh
                csd_average = n_segs_coh * csd_coh
            else:
                psd_1_average += n_segs_coh * psd_1_coh
                psd_2_average += n_segs_coh * psd_2_coh
                csd_average += n_segs_coh * csd_coh

        if psd_1_average is not None and psd_2_average is not None:
            combined_coherence = calculate_coherence(
                psd_1_average / n_segs_coh_total,
                psd_2_average / n_segs_coh_total,
                csd_average / n_segs_coh_total,
            )
        else:
            logger.info("Removed all the coherences due to bad data.")
            combined_coherence = None
            n_segs_coh_total = 0
        np.savez(
            coherence_path,
            coherence=combined_coherence,
            frequencies=coh_frequencies,
            n_segs_coh=n_segs_coh_total,
        )

        logger.info(
            f"Saved file with coherence spectrum:\n {coherence_path}."
        )

    # Same effect as the site-provided exit(), without depending on `site`.
    raise SystemExit


# Run the combination only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

