#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import jks, sys, os, math, glob, tempfile, subprocess, numpy
from scipy.special import erf

# Six positional arguments are required; otherwise print the usage line and stop.
if len(sys.argv) < 7:
    print("%s out.pdf in.jks tag t nbins label" % sys.argv[0])
    sys.exit(0)

# Positional arguments: output PDF, input file, tag, component index,
# number of histogram bins and the legend label of the histogram.
fout, fin, tag = sys.argv[1:4]
t, nbins = int(sys.argv[4]), int(sys.argv[5])
label = sys.argv[6]

# Scratch files: gnuplot PostScript output, the intermediate PDF, and a
# description (working directory + command line) that is later embedded
# into the final PDF.
fps = tempfile.NamedTemporaryFile()
fpdf = tempfile.NamedTemporaryFile()
desc = tempfile.NamedTemporaryFile()
desc.write(str({ "pwd" : os.getcwd(), "argv" : sys.argv }).encode("utf-8"))
desc.flush()

# Load the input, keep only the requested tag and compress the result.
jks0 = jks.resamples(fin)
jks2 = jks.resamples()
jks2.take(jks0, tag)
jks2.compress()
print("Compressed to %d configs" % len(jks2.tags))
# The code further down relies on the tags already being in sorted order.
assert sorted(jks2.tags) == jks2.tags

# gnuplot reads the complete command script from stdin once it is assembled.
gp = subprocess.Popen(
    ['gnuplot'],
    stdout=subprocess.PIPE,
    stdin=subprocess.PIPE,
    stderr=subprocess.STDOUT,
)

# Preamble of the gnuplot script; later sections keep appending to cmds.
cmds = "".join([
    "set terminal postscript color enhanced size 16cm, 12cm\n",
    "set termoption dash\n",
    "set output '%s'\n" % fps.name,
    "set style fill empty\n",
])

#print ret[0]
#import random
#random.seed(10)
#res0=[]
#res1=[]
#for i in range(50000):
#    mm=jks.measurements([ [ random.gauss(2.0,0.8), random.gauss(1.0,0.4) ] for i in range(8) ])
#    p=mm.jackknife().prepare(lambda mi: mi.mean()[0]/mi.mean()[1])
#    res0.append(p.mean())
#    res1.append(p.mean() - p.bias()[0])
#m=jks.measurements(res0)
#print m.mean(), m.jackknife().cov(lambda mi: mi.mean())[0][0]**0.5
#m=jks.measurements(res1)
#print m.mean(), m.jackknife().cov(lambda mi: mi.mean())[0][0]**0.5
#sys.exit(0)
def ac(m,t,maxdist):
    """Return the normalized autocorrelation of component t versus config distance.

    m        object providing ``config_tags`` (strings whose second
             "-"-separated field is an integer config number), ``mean()``
             and ``data`` (indexed as ``data[config][t]``)
    t        component index into each measurement
    maxdist  largest config-number distance to include

    Returns ``(distances, values)``: the sorted distances that occurred and,
    for each, the averaged product of deviations from the mean divided by
    the distance-0 average.
    """
    numbers = [int(name.split("-")[1]) for name in m.config_tags]
    mean_t = m.mean()[t]
    total = {}
    hits = {}
    for i, first in enumerate(numbers):
        for j in range(i, len(numbers)):
            lag = numbers[j] - first
            # Stop at the first partner that is too far away; this presumes
            # the config numbers are in ascending order.
            if lag > maxdist:
                break
            product = (m.data[i][t] - mean_t) * (m.data[j][t] - mean_t)
            if lag in total:
                hits[lag] += 1.
                total[lag] += product
            else:
                total[lag] = product
                hits[lag] = 1.
    # Average per distance, then normalize by the zero-distance entry.
    averaged = {lag: total[lag] / hits[lag] for lag in total}
    lags = sorted(averaged)
    return lags, [averaged[lag] / averaged[0] for lag in lags]
          
def create_file(tag,t):
    """Write the gnuplot data file for one observable and return plot inputs.

    Reads the entry stored under ``tag`` in the module-level ``jks2`` and
    looks at component ``t`` of every measurement.  The number of histogram
    bins is the module-level ``nbins``.  The temporary file holds gnuplot
    data sets (separated by two blank lines) in this order:

      index 0    histogram of the values: bin centre, count, bin width
      index 1    expected Gaussian counts for the same bins
      index 2    mean and jackknife error of the unblocked values
      index 3    mean and jackknife error after ``block(2)``
      index 4-6  mean and error of each part for a split into 2, 4, 8 parts
      index 7    normalized autocorrelation (see ``ac``) with jackknife error

    Returns the tuple ``(f, avg, std, bias, lo, hi, N, avgP, fit, maxdist)``:
      f        the still-open temporary data file
      avg      ``m.mean()[t]``
      std      square root of ``m.cov()[t][t]``
      bias     ``jk.bias()[t]``
      lo, hi   histogram range extended by one bin width on each side
      N        number of values before any trimming
      avgP     mean computed in the last blocking step
      fit      dict keyed by ``(parts, blocks)`` with the ``jks.plateau`` fit
               of the per-part means
      maxdist  ten times the smallest spacing of consecutive config numbers
    """
    jk=jks2.get(tag)
    m=jk.scaled_measurements()

    avg=m.mean()[t]
    std=m.cov()[t][t]**0.5
    bias=jk.bias()[t]

    values=[ m.data[i][t] for i in range(len(m.data)) ]
    args=numpy.argsort(values)

    # Report the most extreme configurations to help spot outliers.
    if len(values) > 5:
        print("Max 5")
        for i in range(1,6):
            print(jk.tags[args[-i]], values[args[-i]])

        print("Min 5")
        for i in reversed(range(5)):
            print(jk.tags[args[i]], values[args[i]])

    # Histogram over a range that is symmetric around the mean and wide
    # enough to contain the most distant value.
    vd=max([ math.fabs(v-avg) for v in values])
    v0=avg-vd
    v1=avg+vd
    dv=(v1-v0) / nbins
    bins=[ int(math.floor((v - v0)/dv)) for v in values ]
    # A value exactly on the upper edge belongs to the last bin.
    bins=[ b if b<nbins else nbins-1 for b in bins ]
    counts=[ bins.count(b) for b in range(nbins) ]
    f=tempfile.NamedTemporaryFile()
    for i in range(nbins):
        f.write(("%.15g %.15g %.15g\n" % (v0+(i+0.5)*dv,counts[i],dv)).encode("utf-8"))
    f.write("\n\n".encode("utf-8"))

    # Expected counts per bin for a Gaussian with the measured mean and
    # standard deviation (difference of the CDF at the two bin edges).
    v0s=[ v0 + i*dv for i in range(nbins) ]
    v1s=[ v0 + (i+1)*dv for i in range(nbins) ]
    countsExpected=[ len(values)*0.5*(erf((avg-v0s[i])/math.sqrt(2.)/std) - erf((avg-v1s[i])/math.sqrt(2.)/std)) for i in range(nbins) ]

    for i in range(nbins):
        f.write(("%.15g %.15g %.15g\n" % (v0+(i+0.5)*dv,countsExpected[i],dv)).encode("utf-8"))

    f.write("\n\n".encode("utf-8"))
    f.flush()

    N=len(values)
    nmaxbin=2
    # Tags that correspond one-to-one to the (possibly trimmed) values.
    tags=jk.tags
    if len(values) % nmaxbin != 0:
        n0=len(values)
        ndrop=n0 % nmaxbin
        values=values[ndrop:]
        # Drop the same leading entries from the tags; otherwise the
        # "Part i / n" listing below is shifted against the values that
        # actually enter each part.
        tags=tags[ndrop:]
        print("Warning: configuration number %d not divisible by %d, remove one" % (n0,nmaxbin))
        print(jk.tags)

    mt=jks.measurements(values)
    mtt=mt

    # Mean and jackknife error without blocking and with blocks of two.
    idx=0
    binErr={}
    for blocks in [ 1, 2 ]:
        if blocks != 1:
            mtt=mt.block(blocks)
        avgP=mtt.mean()[0]
        binErr[blocks]=mtt.jackknife().cov(lambda mi: mi.mean()[0])[0][0]**0.5
        f.write(("%.15g %.15g %.15g\n" % (idx,avgP,binErr[blocks])).encode("utf-8"))
        f.write("\n\n".encode("utf-8"))
        f.flush()
        idx+=1

    print("B2/B1 = %g" % (binErr[2]/binErr[1]))

    # Split the series into consecutive parts and fit a plateau through the
    # per-part means, using their squared errors as a diagonal covariance.
    fit={}
    for parts in [ 2, 4, 8 ]:
        ngrp=len(values)//parts

        for blocks in [ 1 ]:
            res=[]

            idx=1
            for i in range(parts):
                print("Part %d / %d = %s" % (i,parts,str(tags[i*ngrp:(i+1)*ngrp])))
                mtt=mt.subset(range(i*ngrp,(i+1)*ngrp))
                if blocks != 1:
                    mtt=mtt.block(blocks)
                v,e=mtt.mean()[0],mtt.jackknife().cov(lambda mi: mi.mean()[0])[0][0]**0.5
                res.append( (v,e) )
                f.write(("%.15g %.15g %.15g\n" % (idx,v,e)).encode("utf-8"))
                f.flush()
                idx+=1

            f.write("\n\n".encode("utf-8"))
            fit[(parts,blocks)]=jks.plateau([ x[0] for x in res ], [ [ res[i][1]**2. if i==j else 0.0 for i in range(len(res)) ] for j in range(len(res)) ],
                                              range(len(res))).fit()

    # Now plot autocorrelator for each ensemble
    ensembles=set([ a.split("-")[0] for a in jks2.tags ])
    assert(len(ensembles)==1)
    cnr=[ int(a.split("-")[1]) for a in jks2.tags ]
    dists=sorted(list(set([ cnr[i+1] - cnr[i] for i in range(len(cnr)-1) ])))
    # Look at distances up to ten times the smallest config spacing.
    maxdist=dists[0]*10

    rp,rv=ac(m,t,maxdist)
    rcv=m.jackknife().cov(lambda mi: ac(mi,t,maxdist)[1])
    for de in range(len(rp)):
        f.write(("%d %.15g %.15g\n" % (rp[de],rv[de],rcv[de][de]**0.5)).encode("utf-8"))
    f.flush()

    return f, avg, std, bias, v0-dv, v1+dv, N, avgP, fit, maxdist

# Number of plot specifications appended since the counter was last reset.
nplt=0
def plot(cc):
    """Append the plot specification cc to the global gnuplot script cmds.

    The first specification after nplt was reset to 0 is introduced with
    "plot "; every further one is joined to the same command with ", ".
    """
    global nplt,cmds
    lead = ", " if nplt else "plot "
    cmds += lead + cc
    nplt += 1

# Write the data file and collect the numbers that parameterize the pages.
datafile,avg,std,bias,vl,vh,N,avgP,fit,maxdist = create_file(tag,t)

for caption, number in (("Avg", avg), ("Std", std), ("Bias", bias)):
    print("%s: %g" % (caption, number))

# Fragments shared by several pages.
root_n = math.sqrt(N)
original_series = "'%s' index 2 using 1:2:3 w yerr lt 1 lw 1.0 title 'Original'" % (datafile.name)
mean_line = ", (%.15g) lt 8 title 'Bias-corrected mean'" % (avgP-bias)
p_label = "set label 1 \"p=%.2g\" at graph 1, graph 1 offset -1,-1 right front\n"

# Page 1: histogram of the values (drawn sideways) on top of the expected
# Gaussian bin counts.
cmds += "".join([
    "\n",
    "set xrange [0:*]\n",
    "set yrange [(%.15g):(%.15g)]\n" % (vl,vh),
])
plot("".join([
    "'%s' index 1 using ($2*0.5):($1):($2*0.5):(0.5*$3) w boxxyerrorbars lt 2 lw 0.8 title 'Expected Gaussian Bins'," % (datafile.name),
    "'' index 0 using ($2*0.5):($1):($2*0.5):(0.5*$3) w boxxyerrorbars lt 1 lw 1 title '%s'" % (label),
    ", '' index 0 using 2:1:(sqrt($2)) w xerr lt 1 lw 0.8 notitle",
]))

# Page 2: unblocked result next to the block-by-2 result.
cmds += "".join([
    "\n",
    "unset xtics\n",
    "set xrange [-1:3]\n",
    "set yrange [(%.15g):(%.15g)]\n" % (avgP-std*3.0/root_n,avgP+std*4.0/root_n),
    "set key top left Left\n",
])
nplt=0
plot(original_series
     + ", '' index 3 using 1:2:3 w yerr lt 2 lw 1.0 title 'Block-by-2'"
     + mean_line)

# Page 3: the two halves of the series, labelled with the fit's "p" entry.
nplt=0
cmds += "\n"
cmds += "set xrange [-1:3]\n"
cmds += p_label % fit[(2,1)]["p"]
plot(original_series
     + ", '' index 4 using 1:2:3 w yerr lt 3 lw 1.0 title '#/2'"
     + mean_line)

# Pages 4 and 5: quarters and eighths, each with a wider y range.
for parts, index in ((4, 5), (8, 6)):
    cmds += "\n"
    cmds += p_label % fit[(parts,1)]["p"]
    cmds += "set xrange [-1:%d]\n" % (parts+1)
    cmds += "set yrange [(%.15g):(%.15g)]\n" % (avgP-std*20.0/root_n,avgP+std*20.0/root_n)
    nplt=0
    plot(original_series
         + ", '' index %d using 1:2:3 w yerr lt 3 lw 1.0 title '#/%d'" % (index, parts)
         + mean_line)

# Page 6: autocorrelation with an exponential fitted by gnuplot.
cmds += "".join([
    "\n",
    "f(x)=exp(-x/tauexp)\n",
    "tauexp=%d\n" % maxdist,
    "set yrange [-1:1]\n",
    "set xrange [0:%d]\n" % (maxdist+1),
    "fit f(x) '%s' index 7 via tauexp\n" % datafile.name,
    "\n",
    "\n",
    "unset label 1\n",
    "set key top right\n",
    "set xtics\n",
])
nplt=0
plot("'%s' index 7 using 1:2:3 w yerr lt 1 lw 1.0 title 'Autocorrelation', 0 lt 8 notitle, f(x) title sprintf(\"{/Symbol t}_{exp}=%%g\", tauexp)" % (datafile.name))

cmds += "\n"

# Run gnuplot on the assembled script, then convert the PostScript output to
# PDF, crop it into the requested output file and embed the description.
ret = gp.communicate(input=cmds.encode("utf-8"))
for command in (
    ["ps2pdf", fps.name, fpdf.name],
    ["pdfcrop", fpdf.name, fout],
    ["exiftool", fout, "-Description<=%s" % desc.name, "-overwrite_original"],
):
    subprocess.Popen(command, stdout=subprocess.PIPE).communicate()

