#!/usr/bin/env python
import sys
import errno
import os
import os.path
import argparse
from six import print_, string_types
import numpy
import scipy.linalg
from libsncompress import base, evaluator


def _info(text, stream=sys.stderr):
    """Print *text* to *stream* through six's ``print_``.

    The default stream is the ``sys.stderr`` object that exists when this
    function is defined (the default is evaluated once, at definition time).
    """
    print_(text, file=stream)


def unique(list_of_floats):
    """Return True when no value occurs more than once in *list_of_floats*.

    The argument may be any iterable of hashable values, including a one-shot
    iterator: it is materialized exactly once so that it can be both counted
    and de-duplicated.  An empty input is considered unique.
    """
    values = list(list_of_floats)
    # A set drops duplicates, so the sizes agree only if there were none.
    return len(set(values)) == len(values)


def mk_intermediate_dirs(path):
    """Create intermediate directories to the file path.

    Only the directory part of *path* (``os.path.dirname(path)``) is created;
    the final path component is left alone.

    Raises OSError if the directories cannot be created, for example because
    of insufficient privileges or because a non-directory already occupies
    the target name.  An already existing directory is not an error.
    """
    dirname = os.path.dirname(path)
    try:
        os.makedirs(dirname)
    except OSError as the_error:
        # The directory may already be there (repeated call, or another
        # process created it first).  That is the desired end state, so only
        # re-raise when the failure is something else, or when the existing
        # entry is not actually a directory.
        if the_error.errno != errno.EEXIST or not os.path.isdir(dirname):
            raise


def ensured_save(file_, arr, backend=numpy.savetxt, **kwargs):
    """Wrapper around numpy's data output functions that ensures the
    intermediate directories exist.

    Calls ``backend(file_, arr, **kwargs)`` and returns its result.  If that
    fails with an IOError whose errno is ENOENT and *file_* is a string (that
    is, a path rather than an already opened file object), the intermediate
    directories of the path are created and the backend is called one more
    time.  Any other failure, including a failure of the second attempt, is
    propagated to the caller.

    file_   -- path or file-like object handed to the backend
    arr     -- data to save; it is passed to the backend unchanged
    backend -- callable doing the actual output (default: numpy.savetxt)
    kwargs  -- extra keyword arguments forwarded to the backend
    """
    # There's no attempt in guarding the array to be saved against modification
    # between the first try and the retry.
    try:
        return backend(file_, arr, **kwargs)
    except IOError as the_error:
        # Following the implementation details of numpy's savers: a path is
        # recognized by being a string, hence the isinstance() check.
        if the_error.errno != errno.ENOENT:     # Definitely not NOENT
            raise                               # Fail as intended
        if not isinstance(file_, string_types):
            raise
    # In this case, file_ is a path whose directories are probably missing.
    # Retry exactly once, outside the handler, so that a persistent ENOENT is
    # reported as itself instead of driving further retries.
    mk_intermediate_dirs(file_)
    return backend(file_, arr, **kwargs)


# Command-line interface definition.  Every option is optional; the defaults
# point at files in the current working directory.
theparser = argparse.ArgumentParser(description="Compress the JLA dataset")

# Input locations.
theparser.add_argument(
    "-d", "--covdir",
    metavar="DIR",
    default="./covmat",
    help="path to directory of FITS covariance files (default: ./covmat)")
theparser.add_argument(
    "-t", "--table",
    metavar="FILE",
    default="./jla_lcparams.txt",
    help="path to JLA data table file (default: ./jla_lcparams.txt)")

# Output naming.
theparser.add_argument(
    "-p", "--output-prefix",
    metavar="PREFIX",
    default="",
    help="prefix to output file names")

# One or more floating-point control points.
theparser.add_argument(
    "-c", "--controls",
    metavar="z",
    type=float,
    nargs="+",
    default=[],
    help="locations of control points (as redshift)")

# Boolean switches.
theparser.add_argument(
    "-n", "--no-logdet",
    action="store_true",
    help=("don't use the correct conditional probability"
          " (use at your own risk)"))
theparser.add_argument(
    "-v", "--verbose",
    action="store_true",
    help="enable verbose output")

# Parse the command line and validate the control points before doing any
# expensive work.  theparser.error() prints the usage message and exits.
optns = theparser.parse_args(sys.argv[1:])
if optns.controls:
    if not unique(optns.controls):
        theparser.error("option -c: non-unique control points")
    if len(optns.controls) < 2:
        theparser.error("option -c: not enough control points (at least 2)")
    # The control points are passed through log10 below; zero, negative, NaN
    # or infinite values would turn into -inf/nan/inf bin locations, so they
    # are rejected here.  (Comparisons with NaN are false, which rejects it.)
    if not all(0.0 < z < float("inf") for z in optns.controls):
        theparser.error("option -c: control points must be positive and "
                        "finite")
    # NOTE(review): the extra brackets make this a 2-D array of shape
    # (1, N); kept as in the original -- confirm this is what BinnedSN
    # expects.
    logbins = numpy.log10(numpy.array([optns.controls]))
else:
    # No control points given: pass None on to BinnedSN.
    logbins = None
n = base.BinnedSN(os.path.abspath(optns.covdir),
                  os.path.abspath(optns.table), logbins)

# Parameter count: three fixed parameters plus one per control point.
# Presumably "nnuis" means nuisance parameters -- TODO confirm.  Neither name
# is used again in this part of the script.
nnuis = 3
npars = nnuis + n.bins.ncontrolpoints

# Run the minimization; -n/--no-logdet switches the "withlogdet" flag off.
ce = evaluator.CovEvaluator(n, withlogdet=(not optns.no_logdet))
res = ce.minimize()
if optns.verbose:
    _info(res)
if not res.success:
    _info("error: failed to find optimal parameters")
    _info("the message was: %s" % res.message)
    sys.exit(1)
if optns.verbose:
    _info("info: success")

# Write the three result files as text.  The names are
# <prefix>{cov,mean,redshift}[-no-logdet].txt; missing directories in the
# prefix are created by ensured_save().
suffix = "-no-logdet" if optns.no_logdet else ""
ffmt = "% .12e"
covpath = "%s%s%s.txt" % (optns.output_prefix, "cov", suffix)
meanpath = "%s%s%s.txt" % (optns.output_prefix, "mean", suffix)
zpath = "%s%s%s.txt" % (optns.output_prefix, "redshift", suffix)
ensured_save(covpath, ce.compressed_cov, fmt=ffmt)
ensured_save(meanpath, res.x, fmt=ffmt)
# The values from itercontrolpoints() are raised as powers of ten before
# being saved to the "redshift" file.
ensured_save(zpath,
             numpy.power(10, numpy.fromiter(n.bins.itercontrolpoints(),
                                            dtype=float)),
             fmt=ffmt)
