#! /usr/bin/env python

# Copyright (C) 2017 Collin Capano
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.


#
# =============================================================================
#
#                                   Preamble
#
# =============================================================================
#
"""Extracts posterior samples from a Sampler HDF file, writing to a
posterior HDF file. Parameters can be renamed in the output file using the
--parameters option.
"""

import os
import argparse
import logging
import numpy
import pycbc
from pycbc.inference.io import (ResultsArgumentParser, results_from_cli,
                                PosteriorFile, loadfile)

# Build the command-line interface. The module docstring doubles as the
# program description shown by --help.
# NOTE(review): defaultparams='all' and autoparamlabels=False are passed
# through to ResultsArgumentParser; presumably they select every parameter by
# default and disable automatic label generation -- confirm against
# pycbc.inference.io.
parser = ResultsArgumentParser(defaultparams='all', autoparamlabels=False,
                               description=__doc__)

parser.add_argument("--output-file", type=str, required=True,
                    help="Output file to create.")
parser.add_argument("--force", action="store_true", default=False,
                    help="If the output-file already exists, overwrite it. "
                         "Otherwise, an IOError is raised.")
# One or more group names may be given; the special value 'all' is expanded
# later into every group other than the samples group.
parser.add_argument("--skip-groups", default=None, nargs="+",
                    help="Don't write the specified groups in the output "
                         "(aside from samples; samples are always written), "
                         "for example, 'sampler_info'. If 'all', skip all "
                         "groups, only writing the samples. Default is to "
                         "write all groups.")
parser.add_argument("--verbose", action="store_true", default=False,
                    help="Be verbose")

# Parse the command line once; everything below reads its settings from opts.
opts = parser.parse_args()

# initialize pycbc logging, passing along the --verbose flag
pycbc.init_logging(opts.verbose)

# refuse to clobber an existing output file unless --force was given
output_exists = os.path.exists(opts.output_file)
if output_exists and not opts.force:
    raise IOError("output file already exists; use --force if you wish to "
                  "overwrite.")

# load the samples; results_from_cli also returns the open input file (fp),
# the parameter names, and the mapping from parameter name to output label
fp, params, labels, samples = results_from_cli(opts)

# Everything from here on runs inside try/finally so that both HDF files are
# closed even if one of the write/copy steps raises part way through.
out = None
try:
    # convert samples to a dict in which the keys are the labels
    samples = {labels[p]: samples[p] for p in params}

    # create the file
    outtype = PosteriorFile.name
    out = loadfile(opts.output_file, 'w', filetype=outtype)

    # write the samples
    out.write_samples(samples)

    # Preserve samples group metadata
    for key, val in fp[fp.samples_group].attrs.items():
        out[out.samples_group].attrs[key] = val

    # Preserve top-level metadata...
    # ...except for: the filetype, since that's set when the file was loaded;
    # the automatic thinning settings, since those will no longer make sense
    skip_attrs = ('filetype', 'thin_start', 'thin_interval', 'thin_end',
                  'thinned_by')
    for key in fp.attrs:
        if key not in skip_attrs:
            out.attrs[key] = fp.attrs[key]

    # store what parameters were renamed, as (parameter, label) pairs
    out.attrs['remapped_params'] = list(labels.items())

    # write the other groups; 'all' expands to every group in the input file
    # other than the samples group
    skip_groups = opts.skip_groups
    if skip_groups is not None and 'all' in skip_groups:
        skip_groups = [group for group in fp.keys()
                       if group != fp.samples_group]
    fp.copy_info(out, ignore=skip_groups)
finally:
    # close and exit; out is None if the output file was never opened
    fp.close()
    if out is not None:
        out.close()
