#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import jks, sys, os, math, glob, time
import numpy as np

# Command line layout:
#   <prog> inout.jks model1 model2 [model3 ...] index weight etag tag
# The program name, the file name and the four trailing arguments account for
# six entries of sys.argv; everything else is a model name.
nmodels = len(sys.argv) - 6
if nmodels < 2:
    # At least two models are required; otherwise print the usage line.
    print("%s inout.jks model1 model2 [model3 ...] index weight etag tag" % sys.argv[0])
    sys.exit(0)

# File that is read here and written back at the end of the script.
fn = sys.argv[1]

# Model names sit between the file name and the four trailing arguments.
models = sys.argv[2:-4]

# Trailing arguments: entry index, weighting scheme, name of the added
# variation, and name under which the result is stored.
index, weight, etag, tag = sys.argv[-4:]
index = int(index)

jks2 = jks.resamples(fn)

# One database entry per requested model, in command-line order.
mm = list(map(jks2.get, models))

def aic(data):
    """Return the log-weight ``-(2*k + chi2)/2`` for one model.

    ``k`` is taken as ``data[-1] + 1`` and ``chi2`` as ``data[-3]``.
    NOTE(review): presumably the fit stores its chi^2 third from last and a
    parameter count last in the result vector -- confirm against the code
    that writes the models.
    """
    n_par = data[-1] + 1
    return -(2 * n_par + data[-3]) / 2.0

def chi2(data):
    """Return the log-weight ``-chi2/2`` for one model.

    The chi^2 value is read from ``data[-3]``.
    """
    return -data[-3] / 2.0

def flat(data):
    """Return the same log-weight (0.0) for every model; ``data`` is ignored."""
    uniform_log_weight = 0.0
    return uniform_log_weight

# Map the "weight" command-line argument to the function that turns a model's
# mean vector into its (unnormalised) log-weight.
_weight_schemes = {
    "aic" : aic,
    "chi2" : chi2,
    "flat" : flat
}

if weight not in _weight_schemes:
    # Report a usable message instead of a bare KeyError traceback; sys.exit
    # with a string prints it to stderr and exits with status 1.
    sys.exit("unknown weight '%s' (expected one of: %s)"
             % (weight, ", ".join(sorted(_weight_schemes))))

model_lnp = _weight_schemes[weight]

# Collect one (central value, error, log-weight) tuple per model for the
# selected entry of the model's mean vector.
dist = []
for model_jk in mm:
    center = model_jk.mean()
    log_weight = model_lnp(center)
    # NOTE(review): the error is taken from cov() here although the histogram
    # section of this script uses tcov(); the value is only carried along in
    # dist and does not enter the weights or the average computed below.
    dist.append((center[index], model_jk.cov()[index][index] ** 0.5, log_weight))

# Convert log-weights into relative weights.  The largest log-weight is
# subtracted first so that the biggest exponent passed to np.exp is zero.
top_lnp = max(dist, key=lambda entry: entry[2])[2]
rel_weights = [np.exp(entry[2] - top_lnp) for entry in dist]

ptot = 0.0
for rel in rel_weights:
    ptot += rel

# Replace the log-weight in every tuple by the normalised probability.
dist = [(val, err, rel / ptot) for (val, err, _lnp), rel in zip(dist, rel_weights)]

# Probability-weighted mean of the model central values ...
ma_mean = 0.0
for val, _err, prob in dist:
    ma_mean += prob * val

# ... and the probability-weighted spread of the models around that mean.
ma_var = 0.0
for val, _err, prob in dist:
    ma_var += (val - ma_mean) ** 2 * prob

# Same weighted combination of the selected entry, expressed as a callback for
# jks2.apply (presumably evaluated once per resample -- confirm in jks).
jk_new = jks2.apply(
    lambda r: [sum([r[name][index] * entry[2] for name, entry in zip(models, dist)])]
)

# Attach mean + sqrt(variance) as an extra variation named by etag.
jk = jk_new.addvar(
    etag,
    [ma_mean + ma_var ** 0.5],
    "Systematic error from model average %s" % str(sys.argv),
)

def hist_bin_index(value, lower, width, nbins):
    """Return the index of the histogram bin containing ``value``.

    Bin ``b`` covers ``[lower + b*width, lower + (b+1)*width)``.  ``None`` is
    returned when ``value`` lies outside all ``nbins`` bins.

    The index is computed with ``math.floor`` rather than ``int``: ``int``
    truncates toward zero, which would assign values from
    ``(lower - width, lower)`` to bin 0 although they are below the range.
    """
    idx = math.floor((value - lower) / width)
    if 0 <= idx < nbins:
        return idx
    return None


if "JKS_DIST" in os.environ:
    # Optional diagnostic: write a histogram of the model-averaged
    # distribution to the file named by the JKS_DIST environment variable.

    # Histogram range: +-5 total errors around the model-averaged value.
    vvv = jk.mean()[0]
    eee = jk.tcov()[0][0]**0.5

    vv_min = vvv - 5*eee
    vv_max = vvv + 5*eee

    nbins = 25

    vv_d = (vv_max - vv_min) / nbins

    hist = [0]*nbins

    # sample over models
    for i, name in enumerate(models):
        # The generator is re-seeded for every model, so each model is
        # sampled from the same underlying random stream.
        np.random.seed(13)
        mk = jks2.get(name)
        mkm = mk.mean()[index]
        mke = mk.tcov()[index][index]**0.5

        # Gaussian samples around this model's value with its total error.
        vv = [np.random.normal(mkm, mke) for _ in range(100000)]

        # Each in-range sample contributes the model's probability to its bin.
        for v in vv:
            v_idx = hist_bin_index(v, vv_min, vv_d, nbins)
            if v_idx is not None:
                hist[v_idx] += dist[i][2]

    # Normalise to the weight that actually landed inside the range.
    thist = sum(hist)

    # The file is opened only after the histogram is complete, and the
    # context manager closes it even if a write fails.
    with open(os.environ["JKS_DIST"], "wt") as fo:
        prc = 0.0
        for ih in range(nbins):
            x = vv_min + vv_d * (ih + 0.5)
            pr = hist[ih] / thist
            prc += pr
            # Columns: bin centre, bin probability, lower bin edge,
            # cumulative probability.
            fo.write("%.15g %.15g %.15g %.15g\n" % (x,pr,x - vv_d*0.5, prc))


# Hand the model-averaged result to the database object under the name given
# by the "tag" argument, then save to fn -- the same path the database was
# loaded from, so the input file is presumably updated in place (confirm the
# exact semantics of add/save in the jks module).
jks2.add(tag, jk)
jks2.save(fn)
