#! /usr/bin/env python
# Copyright (C) 2016  Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
# Program description; handed to argparse.ArgumentParser as the text shown at
# the top of ``--help``. The literal must open with exactly three quotes: a
# fourth one becomes a stray '"' at the start of the help text.
__description__ = \
"""Loads a bank of waveforms and compresses them using the specified
compression algorithm. The resulting compressed waveforms are saved to an
hdf file."""

import argparse
import numpy
import h5py
import logging
import pycbc
from pycbc.waveform import compress
from pycbc import waveform
from pycbc.types import FrequencySeries, real_same_precision_as
from pycbc import pnutils
from pycbc import filter


# Command-line interface. Every option defined here is read back from the
# parsed ``args`` namespace further down in the script.
parser = argparse.ArgumentParser(description=__description__)
# The input bank and the output file are both needed for the script to do
# anything, so let argparse reject a missing one up front instead of passing
# None on to the bank loader / hdf writer.
parser.add_argument("--bank-file", type=str, required=True,
                help="Bank hdf file to load.")
parser.add_argument("--output", type=str, required=True,
                help="The hdf file to save the templates and compressed "
                     "waveforms to.")
parser.add_argument("--low-frequency-cutoff", type=float, required=True,
                help="The low frequency cutoff to use for generating "
                     "the waveforms (Hz)")
# add approximant arg
pycbc.waveform.bank.add_approximant_arg(parser)
parser.add_argument("--sample-rate", type=int, required=True,
                help="Half this value sets the maximum frequency of the "
                     "compressed waveforms.")
parser.add_argument("--tmplt-index", nargs=2, type=int, default=None,
                help="Only generate compressed waveforms for the given "
                     "indices in the bank file. Must provide both a start "
                     "and a stop index. Default is to compress all of the "
                     "templates in the bank file.")
parser.add_argument("--compression-algorithm", type=str, required=True,
                choices=compress.compression_algorithms.keys(),
                help="The compression algorithm to use for selecting "
                     "frequency points.")
parser.add_argument("--t-pad", type=float, default=0.001,
                help="The minimum duration used for t(f) in seconds. The "
                     "inverse of this gives the maximum frequency step that "
                     "will be used in the compressed waveforms. Default is "
                     "0.001.")
parser.add_argument("--tolerance", type=float, default=0.001,
                help="The maximum mismatch to allow between the interpolated "
                     "waveform and the full waveform. Points will be "
                     "added to the compressed waveform until its "
                     "interpolation has a mismatch <= this value. Default is "
                     "0.001.")
parser.add_argument("--interpolation", type=str, default="linear",
                help="The interpolation to use for decompressing the "
                     "waveforms for checking tolerance. Options are 'linear', "
                     "or any interpolation recognized by scipy's interp1d "
                     "kind argument. Default is linear.")
parser.add_argument("--precision", type=str, choices=["double", "single"],
                default="double",
                help="What precision to generate and store the waveforms with;"
                     " default is double.")
parser.add_argument("--force", action="store_true", default=False,
                    help="Overwrite the given hdf file if it exists. "
                         "Otherwise, an error is raised.")
parser.add_argument("--verbose", action="store_true", default=False)


# Parse the command line and set up logging according to --verbose.
args = parser.parse_args()
pycbc.init_logging(args.verbose)

# Frequency band: from the requested low-frequency cutoff up to half the
# sample rate.
fmin = args.low_frequency_cutoff
fmax = args.sample_rate/2.

logging.info("loading bank")
# Complex dtype matching the requested precision.
# FIXME: there really should be a function in pycbc.types that provides this
# mapping
dtype = numpy.complex64 if args.precision == "single" else numpy.complex128
# The N (5) and df (0.25) passed here are placeholders; both are overwritten
# for every template before any waveform is generated.
bank = waveform.FilterBank(args.bank_file, 5, 0.25, fmin, dtype,
    approximant=args.approximant)
templates = bank.table
if args.tmplt_index is None:
    # no sub-range requested: work on the whole bank
    imin, imax = 0, templates.size
else:
    # restrict both the local view and the bank itself to [imin, imax)
    imin, imax = args.tmplt_index
    templates = templates[imin:imax]
    bank.table = templates

# Work out the segment length (and hence the df) needed for each template.
logging.info("getting needed dfs")
# Each segment is four times the rough duration estimate rounded up to the
# next power of two. The 4/sample_rate floor keeps very short templates from
# ending up with a segment that holds too few samples.
min_seglen = 4./args.sample_rate
seg_lens = []
for m1, m2 in zip(templates.mass1, templates.mass2):
    duration = compress.rough_time_estimate(m1, m2, fmin)
    seg_lens.append(max(min_seglen, 2**numpy.ceil(numpy.log2(duration))))
seg_lens = 4*numpy.array(seg_lens)

# Create the output file and store the template parameters in it; the
# template_duration field is deliberately left out.
logging.info("writing template info to output")
output = bank.write_to_hdf(args.output, skip_fields='template_duration',
    force=args.force)

# Compress every selected template in turn and store the result in the output
# file under the template's hash.
logging.info("getting compressed amplitude and phase")
try:
    # seg_lens holds one entry per selected template, so iterating over it
    # (rather than over range(imin, imax)) also copes with a --tmplt-index
    # stop value that lies beyond the end of the bank.
    for idx, seg_len in enumerate(seg_lens):
        # update the N, df to use based on this waveform
        N = int(args.sample_rate*seg_len)
        df = 1./seg_len
        bank.delta_f = df
        bank.N = N
        # floor division so that filter_length stays an integer on Python 3
        bank.filter_length = N//2 + 1
        # We'll budge the waveform starting frequency down by 1df to ensure
        # the first point is correct
        bank.f_lower = fmin-df
        bank.kmin = int(bank.f_lower / df)
        # scratch space handed to compress_waveform for decompression
        decomp_scratch = FrequencySeries(numpy.zeros(N, dtype=dtype),
            delta_f=df)
        # generate the waveform
        htilde = bank[idx]
        tmplt = bank.table[idx]
        kmax = htilde.end_idx
        # highest frequency passed to the sample-point selection
        fhigh = (kmax-1)*df
        real_dtype = real_same_precision_as(htilde)

        # get the compressed sample points
        if args.compression_algorithm == 'mchirp':
            sample_points = compress.mchirp_compression(
                tmplt.mass1, tmplt.mass2, fmin, fhigh,
                min_seglen=args.t_pad, df_multiple=df).astype(real_dtype)
        elif args.compression_algorithm == 'spa':
            sample_points = compress.spa_compression(
                htilde, fmin, fhigh,
                min_seglen=args.t_pad).astype(real_dtype)
        else:
            raise ValueError("unrecognized compression algorithm %s" %(
                args.compression_algorithm))

        # compress
        hcompressed = compress.compress_waveform(
            htilde, sample_points, args.tolerance, args.interpolation,
            decomp_scratch=decomp_scratch)

        # save results
        hcompressed.write_to_hdf(output, tmplt.template_hash)

    logging.info("finished")
finally:
    # Release both hdf handles even if a template fails part-way through.
    # NOTE(review): assumes ``output`` is the open h5py file handle returned
    # by bank.write_to_hdf -- confirm against the pycbc version in use.
    output.close()
    bank.filehandler.close()
