#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
#
# Title:  Jackknife / Variation file reader
# Author: Christoph Lehner
# Date:   2018
#
import pickle, sys, numpy, math, os
import lz4.frame

# Treat special case of single element
def mka(a):
    """Return ``a`` as a nested list, wrapping a lone element as ``[[a]]``.

    Lists (including list subclasses) are returned unchanged; anything
    else is treated as a single matrix element and wrapped into a 1x1
    nested list so that it can be indexed as ``result[i][j]``.
    """
    # isinstance instead of an exact type comparison so that list
    # subclasses are not mistaken for a single element.
    if isinstance(a, list):
        return a
    return [ [ a ] ]

# Human readable format of errors
def format(value,errors,etags,force_no_digits_shift=False):
    """Format a value and its errors in compact parenthesis notation.

    The largest absolute error among ``etags`` fixes the decimal place of
    the output (two error digits are shown), the value is printed in fixed
    point up to that place, and every non-zero error is appended as
    ``(NN)_{tag}``, e.g. ``1.500(12)_{stat}``.  If the leading digit sits
    far from the period, a power of ten is factored out and appended as
    `` \\times 10^{k}``.

    NOTE: the name shadows the builtin ``format``; it is kept because the
    rest of the file calls the function under this name.

    Parameters:
        value: number to format.
        errors: mapping from error tag to error size.
        etags: tags (keys of ``errors``) to display, in output order.
        force_no_digits_shift: if true, never factor out a power of ten.

    Returns:
        The formatted string.  If every selected error is zero, or
        ``etags`` is empty, the bare value in ``%.15g`` format.
    """
    err_digits=2

    def count_digits_left_of_period(x):
        # Number of decimal digits in front of the period of |x|, x != 0.
        # The 1e-10 offset is added before the floor so that a log ratio
        # landing marginally below an integer is not rounded down.
        return int(math.floor(1e-10 + math.log(math.fabs(x)) / math.log(10.)) + 1)

    # default=0.0 lets an empty tag list fall through to the plain-value path.
    error=max(( math.fabs(errors[e]) for e in etags ), default=0.0)
    if error == 0.0:
        # Checked before taking any logarithm so that a zero value with
        # zero errors is printed instead of raising a math domain error.
        return "%.15g" % (value)
    err_digits_left_of_period=count_digits_left_of_period(error)

    if value == 0:
        # log(0) is undefined; let the error alone decide the layout, which
        # is the limit of an ever smaller non-zero value.
        val_digits_left_of_period=err_digits_left_of_period
    else:
        val_digits_left_of_period=count_digits_left_of_period(value)

    digits_left_of_period=max(val_digits_left_of_period, err_digits_left_of_period)
    # Negative: number of decimals to print; positive: the last shown error
    # digit sits that many places left of the period.
    digits_right_of_period=err_digits_left_of_period - err_digits
    if digits_left_of_period < 0:
        digits_shift=-digits_left_of_period+1
    elif digits_right_of_period > 0:
        digits_shift=-digits_right_of_period
    else:
        digits_shift=0
    if force_no_digits_shift:
        digits_shift=0
    value*=10**digits_shift
    # NOTE(review): the errors are rescaled by the unshifted decimal place
    # even when force_no_digits_shift suppresses a shift for large errors
    # (digits_right_of_period > 0) - confirm this is intended.
    errors_rescaled={ e: errors[e]*10**-digits_right_of_period for e in etags }
    digits_left_of_period+=digits_shift
    digits_right_of_period+=digits_shift

    if digits_right_of_period > 0:
        digits_right_of_period = 0

    s=("{0:%d.%df}" % (digits_left_of_period,-digits_right_of_period)).format(value)
    for e in etags:
        if errors[e] == 0.0:
            continue
        if digits_right_of_period == -1:
            # Only one decimal is printed for the value, so show the two
            # error digits with the period in between (e.g. "(1.2)").
            s = s + "(" + "{0:02.1f}".format(errors_rescaled[e]/10.) + ")_{%s}" % e
        else:
            s = s + "(" + "{0:02.0f}".format(errors_rescaled[e]) + ")_{%s}" % e
    if digits_shift != 0:
        if digits_shift == -1:
            s+=" \\times 10"
        else:
            s+=" \\times 10^{%d}" % (-digits_shift)
    return s

# Main logic
if len(sys.argv) > 1:
    fn=sys.argv[1]

    # NOTE(security): pickle can execute arbitrary code while loading; only
    # open files from trusted sources.
    # Try an lz4-framed pickle first and fall back to a plain pickle.  Only
    # Exception is caught so that KeyboardInterrupt and SystemExit are not
    # swallowed by the fallback.
    try:
        with lz4.frame.open(fn, mode='rb') as f:
            p = pickle.load(f, encoding='latin1', errors='')
    except Exception:
        with open(fn, 'rb') as f:
            p = pickle.load(f, encoding='latin1')

    # Tags starting with "!" name variations (stored here without the "!");
    # all remaining tags are listed as configurations.
    conf_tags = [ x for x in p["tags"] if x[0] != "!" ]
    var_tags = [ x[1:] for x in p["tags"] if x[0] == "!" ]
    if len(sys.argv) >= 3:
        tag=sys.argv[2]
        mean=numpy.array(p["set"][tag]["orig"])
        # Deviation of every configuration block from the "orig" value.
        stat_blocks=[ numpy.array(p["set"][tag]["blocks"][i]) - mean for i, x in enumerate(p["tags"]) if x[0] != "!" ]
        # Per variation: twice the two-sample covariance (ddof=1) of the
        # variation block and the "orig" value, which equals the outer
        # product of their difference.
        vars = { v: mka((numpy.cov(m=[ numpy.array(p["set"][tag]["blocks"][p["tags"].index("!"+v)]), mean ],rowvar=0, ddof=1)*2.0).tolist()) for v in var_tags }
        n=len(stat_blocks)
        if n > 0:
            if "BIN" in os.environ:
                m=int(os.environ["BIN"])
                # The prefactors below divide by (number of bins - 1), so at
                # least two complete bins are required.
                if m < 1 or n // m < 2:
                    sys.exit("BIN=%d cannot be used with %d configurations: at least two complete bins are needed" % (m, n))
                # Drop trailing configurations that do not fill a bin.
                n -= n % m

                sys.stderr.write(f"# bin {m} configurations: {[conf_tags[i*m:(i+1)*m] for i in range(n//m)]}\n")

                # first construct resamples
                cyclic_average = True

                nbins = n // m
                # Prefactors do not depend on the bin or on the cyclic shift.
                block_scale = (n-m)/m/(nbins-1)
                cov_scale = (nbins-1)**2.0/nbins

                if not cyclic_average:
                    rec_stat_blocks = [ block_scale*numpy.sum(stat_blocks[i*m:(i+1)*m],axis=0) for i in range(nbins) ]
                    rec_stat_cov=mka((cov_scale*numpy.cov(m=rec_stat_blocks, rowvar=0, ddof=1)).tolist())
                    stat_cov = rec_stat_cov
                else:
                    # Original author's note: something goes wrong here.
                    # NOTE(review): numpy.roll shifts the full list of blocks,
                    # including any trailing ones dropped by the truncation of
                    # n above, so those can enter the bins for ic > 0 - confirm
                    # whether this is the problem.
                    a = []
                    for ic in range(m):
                        stat_blocks_cycled = numpy.roll(stat_blocks, ic, axis=0)
                        rec_stat_blocks = [ block_scale*numpy.sum(stat_blocks_cycled[i*m:(i+1)*m],axis=0) for i in range(nbins) ]
                        rec_stat_cov=mka((cov_scale*numpy.cov(m=rec_stat_blocks, rowvar=0, ddof=1)).tolist())
                        a.append(rec_stat_cov)
                    # Average the covariance over all m bin alignments.
                    stat_cov = numpy.mean(a, axis=0)
            else:
                # Sample covariance of the deviations scaled by (n-1)^2/n.
                stat_cov=mka(((n-1)**2.0/n*numpy.cov(m=stat_blocks, rowvar=0, ddof=1)).tolist())
        else:
            # No configuration blocks: report zero statistical error.
            stat_cov=[ [ 0.0 for i in range(len(mean)) ] for j in range(len(mean)) ]
        if len(sys.argv) == 3:
            # Table mode: one line per element with the variation
            # covariances summed into a single systematic error.
            print("# t, c[t], stat, sys, (stat^2+sys^2)^0.5")
            for i in range(len(mean)):
                sys_cov = sum([ vars[v][i][i] for v in var_tags ])
                if "JKS_INFO_HUMAN_READABLE" in os.environ:
                    errors={ "stat": stat_cov[i][i]**0.5, "sys": sys_cov**0.5 }
                    sys.stdout.write("%d %s\n" % (i, format(mean[i],errors,["stat","sys"],True)))
                else:
                    sys.stdout.write("%d %.15g %.15g %.15g %.15g\n" % (i, mean[i], stat_cov[i][i]**0.5, sys_cov**0.5, (stat_cov[i][i] + sys_cov)**0.5))
        elif len(sys.argv) == 4:
            # Single-element mode: full-precision line listing every
            # non-zero error, followed by the compact LaTeX-style form.
            t=int(sys.argv[3])
            sys.stdout.write("%.15g" % mean[t])
            if stat_cov[t][t] != 0.0:
                sys.stdout.write(" +- %.15g(stat)" % (stat_cov[t][t]**0.5))
            for v in var_tags:
                if vars[v][t][t] != 0.0:
                    sys.stdout.write(" +- %.15g(%s)" % (vars[v][t][t]**0.5,v))
            sys.stdout.write("\n")
            etags = [ "stat" ] + sorted(var_tags)
            # A variation called "stat" would collide with the statistical
            # error entry in the errors dict below.
            assert("stat" not in var_tags)
            errors=dict([ ("stat",stat_cov[t][t]**0.5) ] + [ (v,vars[v][t][t]**0.5) for v in var_tags ])
            print(format(mean[t],errors,etags))
    else:
        # Overview mode: origin metadata, configurations, variations, tags.
        print("--------------------------------------------------------------------------------")
        print(" Contents of %s" % fn)
        print("--------------------------------------------------------------------------------")
        if "origin" in p:
            print("Origin: ")
            for v in p["origin"]:
                print("- %s: %s" % (v," "*(14-len(v)) + str(p["origin"][v])))
            print("--------------------------------------------------------------------------------")
        print("%d configs:" % len(conf_tags))
        print(conf_tags)
        if len(var_tags) > 0:
            print("--------------------------------------------------------------------------------")
        for f in var_tags:
            print("Variation '%s':" % f)
            if f in p["info"]:
                print("\"%s\"" % p["info"][f])
        print("--------------------------------------------------------------------------------")
        print("Tags:")
        print(list(p["set"].keys()))
        print("--------------------------------------------------------------------------------")

else:
    print("jks_info file [tag [element]]")
    print("")
    print("Without argument, show file overview")
    print("With single argument, print correlator with statistical, systematic, and errors combined in quadrature")
    print("With two arguments list value and all errors in LaTeX usable format")

