#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import stat, jks, sys, os, math, glob, tempfile, subprocess

# Work on a copy of the command line so that sys.argv itself stays untouched
# (the original sys.argv is recorded in the description file further down).
argv=list(sys.argv)

# -k: keep the intermediate "<out.pdf>.input" directory instead of deleting it.
keep="-k" in argv
if keep:
    argv.remove("-k")

# Re-using an already existing intermediate directory is always allowed;
# the "-f" switch that once controlled this is disabled.
force=True

# At least the output file, the input file and one command are required.
if len(argv) < 4:
    print("%s out.pdf in.jks cmd1 [cmd2 ...]" % argv[0])
    sys.exit(0)

fout,fin=argv[1:3]
icmds=argv[3:]

# All intermediate files live in a directory named after the output file.
targetDir="%s.input" % fout
if not os.path.isdir(targetDir):
    os.mkdir(targetDir)
elif not force:
    sys.stderr.write("%s already exists, do not overwrite (use -f if you insist)\n" % targetDir)
    sys.exit(1)

# Names of the intermediate files inside targetDir.
fps_name=targetDir + "/plots.ps"
fpdf_name=targetDir + "/plots.pdf"
desc_name=targetDir + "/desc.txt"
plot_name=targetDir + "/plots.plt"
make_name=targetDir + "/make"

# Record the working directory and the full command line of this run.
with open(desc_name,"wt") as desc:
    desc.write(str({ "pwd" : os.getcwd(), "argv" : sys.argv }))

# Open the input file through jks.resamples; all tags used by the commands
# are looked up on this object.
jks2=jks.resamples(fin)

# Preamble of the gnuplot script: PDF output into the intermediate directory,
# fill/point styles, and an identity x mapping that "xwrap" may redefine.
# (A PostScript terminal writing to fps_name was used here previously.)
cmds="".join([
    "set terminal pdfcairo color enhanced font 'Helvetica,12' size 12cm,9cm fontscale 0.75\n",
    "set output '%s'\n" % fpdf_name,
    "set style fill transparent solid 0.4 noborder\n",
    "set pointintervalbox 0.01\n",
    "xmap(x)=x\n",
])

#print ret[0]


# Names of all data files created so far; they are deleted at the end of the
# run unless -k was given.
files=[]

# Number of the most recently created data file (-1: none created yet).
fid=-1

def NamedTemporaryFile():
    """Open the next numbered data file in targetDir for text writing.

    Files are named data.000, data.001, ... in order of creation. The
    returned file object is open; its path is available as ``.name``.
    """
    global fid
    fid=fid+1
    path="%s/data.%3.3d" % (targetDir,fid)
    return open(path,"wt")

def create_file(tag):
    """Write the entries of *tag* to a new data file and return its name.

    One line is written per entry with four columns: the entry index, the
    mean, the square root of the tcov() diagonal element and the square
    root of the cov() diagonal element.
    """
    sample=jks2.get(tag)
    mean=sample.mean()
    cov=sample.cov()
    tcov=sample.tcov()
    with NamedTemporaryFile() as out:
        for idx in range(len(mean)):
            row=(idx,mean[idx],tcov[idx][idx]**0.5,cov[idx][idx]**0.5)
            out.write("%d %.15g %.15g %.15g\n" % row)
    return out.name

def create_file_p(xtag,ytag,sel=None):
    """Write *xtag* against *ytag* to a new data file and return its name.

    Both tags must have the same number of entries. One line is written per
    entry with six columns: x mean, y mean, sqrt of the x and y tcov()
    diagonal elements, sqrt of the x and y cov() diagonal elements.

    If *sel* is not None, only entries whose index is contained in *sel*
    are written.
    """
    xsample=jks2.get(xtag)
    xmean=xsample.mean()
    xcov=xsample.cov()
    xtcov=xsample.tcov()

    ysample=jks2.get(ytag)
    ymean=ysample.mean()
    ycov=ysample.cov()
    ytcov=ysample.tcov()

    assert len(xmean) == len(ymean)

    with NamedTemporaryFile() as out:
        for idx in range(len(xmean)):
            if sel is not None and idx not in sel:
                continue
            row=(xmean[idx],ymean[idx],
                 xtcov[idx][idx]**0.5,ytcov[idx][idx]**0.5,
                 xcov[idx][idx]**0.5,ycov[idx][idx]**0.5)
            out.write("%.15g %.15g %.15g %.15g %.15g %.15g\n" % row)
    return out.name

def create_file_d(x,y,yerr):
    """Write a single data point "x y yerr" to a new data file.

    The three values are written with %s, i.e. exactly as they were given.
    Returns the name of the file.
    """
    with NamedTemporaryFile() as out:
        out.write("%s %s %s\n" % (x,y,yerr))
    return out.name

def create_fnc_file(tag,fncs,t0,t1):
    """Create a data file for the function *fncs* on [t0,t1]; return its name.

    *fncs* is the body of a Python expression in the variables ``x`` and
    ``p``. The function, the mean and tcov() of *tag*, a step size and 51
    equally spaced x values from t0 to t1 are handed to
    jks.write_confidence_band together with the file name.
    """
    # NOTE: fncs is evaluated as Python source; it comes straight from the
    # command line, so only run this with commands you trust.
    model=eval("lambda x,p: %s" % fncs)
    sample=jks2.get(tag)
    mean=sample.mean()
    tcov=sample.tcov()

    steps=50
    step_eps=1e-8
    # The file is created (empty) here to reserve its name; the content is
    # presumably written by write_confidence_band via the path -- confirm
    # against the jks module.
    with NamedTemporaryFile() as out:
        grid=[ t0 + (t1-t0)*k/steps for k in range(steps+1) ]
        jks.write_confidence_band(model,mean,tcov,step_eps,grid,out.name)
    return out.name

# Number of plot elements added to the current gnuplot "plot" statement.
nplt=0

def plot(cc):
    """Append the plot element *cc* to the gnuplot script in cmds.

    The first element of a page starts a new "plot " statement; every
    further element is joined to it with ", ".
    """
    global nplt,cmds
    prefix="plot " if nplt == 0 else ", "
    cmds=cmds+prefix+cc
    nplt+=1

def get_val(tag,n,fmt):
    """Return entry *n* of the mean of *tag* formatted as a string.

    tag -- name of the quantity, looked up with jks2.get
    n   -- index into the mean; negative values count from the end
    fmt -- %-style format applied to the mean, or None

    With fmt None the value is formatted together with its error (square
    root of the tcov() diagonal element) through jks.gformat; if that error
    is exactly zero the plain value is returned with two significant digits.
    """
    jk=jks2.get(tag)
    mn=jk.mean()

    # Map negative indices to the end of the data, as for Python sequences.
    if n < 0:
        n += len(mn)

    val=mn[n]

    # An explicit format wins; the error is not needed in that case.
    # Identity test: None is a singleton, "== None" is the wrong comparison.
    if fmt is not None:
        return fmt % val

    err=jk.tcov()[n][n]**0.5
    if err == 0.0:
        return "%.2g" % val
    return jks.gformat(val,{ "" : err},[""],times="x")

def has(tag):
    """Return whether *tag* is among the keys of the loaded input."""
    known_tags=jks2.keys()
    return tag in known_tags

def mktitle(tg):
    """Expand ***...*** placeholders in the title string *tg*.

    A placeholder has the form ``***tag***``, ``***tag n***`` or
    ``***tag n fmt***`` (fields separated by single spaces) and is replaced
    by ``get_val(tag, n, fmt)``; n defaults to 0 and fmt to None.
    Expansion stops at the first unterminated placeholder. A placeholder
    with more than three fields ends processing immediately and the string
    is returned as it is at that point.
    """
    marker="***"
    width=len(marker)
    while True:
        start=tg.find(marker)
        if start == -1:
            break
        inner=start+width
        # Offset of the closing marker relative to the placeholder content.
        rel_end=tg[inner:].find(marker)
        if rel_end == -1:
            break
        fields=tg[inner:inner+rel_end].split(" ")
        nfields=len(fields)
        if nfields > 3:
            return tg
        index=int(fields[1]) if nfields >= 2 else 0
        fmt=fields[2] if nfields == 3 else None
        tg=tg[:start] + get_val(fields[0],index,fmt) + tg[inner+rel_end+width:]
    return tg

# Process commands
#
# Every remaining command-line argument is one command; its fields are
# separated by ":". For the data commands the text between the command letter
# and the first ":" is used as the gnuplot line type (<lt>). Titles marked
# with mktitle below may contain ***...*** placeholders.
for c in icmds:
    if not c:
        # An empty argument has no command letter; report it like any other
        # unknown command instead of failing on c[0] below.
        print("Unknown command %s" % c)
        continue
    if c[0] == "c":
        # c<lt>:<tag>[:<title>] -- value vs. index with y error bars; the
        # cov() error is drawn thin, the tcov() error on top. Skipped
        # silently if the tag does not exist.
        a=c[1:].split(":")
        if has(a[1]):
            files.append(create_file(a[1]))
            cc="'%s' using (xmap($1)):2:4 w yerr lt %s lw 0.5 ps 0 notitle, " % (files[-1],a[0])
            if len(a)>2 and a[2]!="":
                cc=cc+"'' using (xmap($1)):2:3 w yerr lt %s title '%s'" % (a[0],mktitle(a[2]))
            else:
                cc=cc+"'' using (xmap($1)):2:3 w yerr lt %s notitle" % (a[0])
            plot(cc)
    elif c[0] == "e":
        # e<lt>:<tag>[:<title>] -- plot the two error columns themselves
        # as points vs. index.
        a=c[1:].split(":")
        if has(a[1]):
            files.append(create_file(a[1]))
            cc="'%s' using (xmap($1)):4 w points lt %s lw 0.5 ps 0 notitle, " % (files[-1],a[0])
            if len(a)>2 and a[2]!="":
                cc=cc+"'' using (xmap($1)):3 w points lt %s title '%s'" % (a[0],mktitle(a[2]))
            else:
                cc=cc+"'' using (xmap($1)):3 w points lt %s notitle" % (a[0])
            plot(cc)
    elif c[0] == "p":
        # p<lt>:<xtag>:<ytag>[:<title>] -- ytag vs. xtag with x and y
        # error bars. Skipped silently unless both tags exist.
        a=c[1:].split(":")
        if has(a[1]) and has(a[2]):
            files.append(create_file_p(a[1],a[2]))
            cc="'%s' using (xmap($1)):2:5:6 w xyerr lt %s lw 0.5 ps 0 notitle, " % (files[-1],a[0])
            if len(a)>3 and a[3]!="":
                cc=cc+"'' using (xmap($1)):2:3:4 w xyerr lt %s title '%s'" % (a[0],mktitle(a[3]))
            else:
                cc=cc+"'' using (xmap($1)):2:3:4 w xyerr lt %s notitle" % (a[0])
            plot(cc)
    elif c[0] == "s":
        # s<lt>:<xtag>:<ytag>:<sel>[:<title>] -- like "p", restricted to
        # the indices contained in <sel>.
        a=c[1:].split(":")
        if has(a[1]) and has(a[2]):
            # NOTE: <sel> is evaluated as a Python expression taken from the
            # command line; only run this with commands you trust.
            files.append(create_file_p(a[1],a[2],eval(a[3])))
            cc="'%s' using (xmap($1)):2:5:6 w xyerr lt %s lw 0.5 ps 0 notitle, " % (files[-1],a[0])
            if len(a)>4 and a[4]!="":
                cc=cc+"'' using (xmap($1)):2:3:4 w xyerr lt %s title '%s'" % (a[0],mktitle(a[4]))
            else:
                cc=cc+"'' using (xmap($1)):2:3:4 w xyerr lt %s notitle" % (a[0])
            plot(cc)
    elif c[0] == "P":
        # P<lt>:<xtag>:<ytag>[:<title>] -- like "p" with line width 2.
        a=c[1:].split(":")
        if has(a[1]) and has(a[2]):
            files.append(create_file_p(a[1],a[2]))
            cc="'%s' using (xmap($1)):2:5:6 w xyerr lt %s lw 0.5 ps 0 notitle, " % (files[-1],a[0])
            if len(a)>3 and a[3]!="":
                cc=cc+"'' using (xmap($1)):2:3:4 w xyerr lt %s lw 2 title '%s'" % (a[0],mktitle(a[3]))
            else:
                cc=cc+"'' using (xmap($1)):2:3:4 w xyerr lt %s lw 2 notitle" % (a[0])
            plot(cc)
    elif c[0] == "Q":
        # Q<lt>:<xtag>:<ytag>[:<title>] -- like "p" with line width 4.
        a=c[1:].split(":")
        if has(a[1]) and has(a[2]):
            files.append(create_file_p(a[1],a[2]))
            cc="'%s' using (xmap($1)):2:5:6 w xyerr lt %s lw 0.5 ps 0 notitle, " % (files[-1],a[0])
            if len(a)>3 and a[3]!="":
                cc=cc+"'' using (xmap($1)):2:3:4 w xyerr lt %s lw 4 title '%s'" % (a[0],mktitle(a[3]))
            else:
                cc=cc+"'' using (xmap($1)):2:3:4 w xyerr lt %s lw 4 notitle" % (a[0])
            plot(cc)
    elif c[0] == "d":
        # d<lt>:<x>:<y>:<yerr>:<title> -- one literal data point with a
        # y error bar; the title is used verbatim.
        a=c[1:].split(":")
        files.append(create_file_d(a[1],a[2],a[3]))
        cc="'%s' using (xmap($1)):2:3 w yerr lt %s lw 1 title '%s'" % (files[-1],a[0],a[4])
        plot(cc)
    elif c[0] == "b":
        # b<lt>:<tag> -- like "c" but without the x mapping, without a
        # title and without checking that the tag exists.
        a=c[1:].split(":")
        files.append(create_file(a[1]))
        cc="'%s' using 1:2:4 w yerr lt %s lw 1 ps 0 notitle, " % (files[-1],a[0])
        cc=cc+"'' using 1:2:3 w yerr lt %s lw 2 notitle" % (a[0])
        plot(cc)
    elif c[0] == "f":
        # f<lt>:<tag>:<function>:<t0>:<t1>[:<title>] -- filled band
        # (column 2 -/+ column 3) plus the central line for a function
        # file created by create_fnc_file.
        a=c[1:].split(":")
        files.append(create_fnc_file(a[1],a[2],float(a[3]),float(a[4])))
        if len(a) < 6:
            cc="'%s' using 1:($2-$3):($2+$3) w filledcu lw 1.5 lt %s notitle, " % (files[-1],a[0])
        else:
            cc="'%s' using 1:($2-$3):($2+$3) w filledcu lw 1.5 lt %s title '%s', " % (files[-1],a[0],mktitle(a[5]))
        cc=cc+"'' using 1:2 w lines lt %s notitle" % (a[0])
        plot(cc)
    elif c[0:2] == "ls":
        # ls[:<axes>] -- set logscale for <axes>, or unset it.
        # Must be tested before the "l" command below.
        a=c.split(":")
        if len(a) == 1:
            cmds=cmds+"unset logscale\n"
        else:
            cmds=cmds+"set logscale %s\n" % a[1]
    elif c[0] == "l":
        # l<lt>:<expr>[:<title>] -- plot a gnuplot expression.
        a=c[1:].split(":")
        if len(a) == 3:
            cc="%s lt %s title '%s'" % (a[1],a[0],a[2])
        elif len(a) == 2:
            cc="%s lt %s notitle" % (a[1],a[0])
        else:
            # Raised explicitly so that the check also holds under "python -O",
            # where a bare assert would vanish and a stale cc would be plotted.
            raise AssertionError("l command expects l<lt>:<expr>[:<title>], got %s" % c)
        plot(cc)
    elif c[0] == "L":
        # L<lt>:<expr>[:<title>] -- like "l" with line width 2.
        a=c[1:].split(":")
        if len(a) == 3:
            cc="%s lt %s lw 2 title '%s'" % (a[1],a[0],a[2])
        elif len(a) == 2:
            cc="%s lt %s lw 2 notitle" % (a[1],a[0])
        else:
            raise AssertionError("L command expects L<lt>:<expr>[:<title>], got %s" % c)
        plot(cc)
    elif c[0:5] == "xwrap":
        # xwrap:<T>:<T2> -- redefine the x mapping used by the data
        # commands to (int(x + (T2-T)) mod T2) - (T2-T).
        a=c.split(":")
        T=int(a[1])
        T2=int(a[2])
        cmds=cmds+("xmap(x)=(int(x + %d) %% %d) - %d\n") % (T2-T,T2,T2-T)
    elif c[0:1] == "k":
        # k:<options> -- passed on as "set key <options>".
        a=c.split(":")
        cmds=cmds+"set key %s\n" % (a[1])
    elif c[0:2] == "xr":
        # xr:<min>:<max> -- x range.
        a=c.split(":")
        cmds=cmds+"set xrange [%s:%s]\n" % (a[1],a[2])
    elif c[0:2] == "xl":
        # xl[:<label>] -- set the x label, or remove it.
        a=c.split(":")
        if len(a)==1:
            cmds=cmds+"unset xlabel\n"
        else:
            cmds=cmds+"set xlabel '%s'\n" % (mktitle(a[1]))
    elif c[0:2] == "xt":
        # xt[:<label0>:<label1>:...] -- rotated x tic labels at positions
        # 0, 1, 2, ...; without labels the x tics are set to an empty list.
        a=c.split(":")
        n=len(a)-1
        if n == 0:
            cmds=cmds+"set xtics norotate\n"
            cmds=cmds+"set xtics ()\n"
        else:
            cmds=cmds+"set xtics rotate by -45 ("
            for l in range(n):
                if l != 0:
                    cmds=cmds+", "
                cmds=cmds+"\"" + a[l+1]+"\" "+str(l)
            cmds=cmds+")\n"
    elif c[0:2] == "yt":
        # yt[:<label0>:<label1>:...] -- rotated y tic labels at positions
        # 0, 1, 2, ...; without labels plain "set ytics" is issued.
        a=c.split(":")
        n=len(a)-1
        # split() never returns an empty list, so the label count n (not
        # len(a)) has to be tested to detect a bare "yt".
        if n == 0:
            cmds=cmds+"set ytics norotate\n"
            cmds=cmds+"set ytics\n"
        else:
            cmds=cmds+"set ytics rotate by -45 ("
            for l in range(n):
                if l != 0:
                    cmds=cmds+", "
                cmds=cmds+"\"" + a[l+1]+"\" "+str(l)
            cmds=cmds+")\n"
    elif c[0:2] == "yl":
        # yl[:<label>] -- set the y label, or remove it.
        a=c.split(":")
        if len(a)==1:
            cmds=cmds+"unset ylabel\n"
        else:
            cmds=cmds+"set ylabel '%s'\n" % (mktitle(a[1]))
    elif c[0:2] == "vl":
        # vl:<x> -- vertical line at <x> over the full graph height.
        a=c.split(":")
        cmds=cmds+"set arrow from %s, graph 0 to %s, graph 1 nohead front\n" % (a[1],a[1])
    elif c[0:2] == "yr":
        # yr:<min>:<max> -- y range.
        a=c.split(":")
        cmds=cmds+"set yrange [%s:%s]\n" % (a[1],a[2])
    elif c == "newpage":
        # End the current plot statement; the next data command starts a
        # new "plot".
        nplt=0
        cmds=cmds+"\n"
    else:
        print("Unknown command %s" % c)

# Terminate the last plot statement.
cmds=cmds+"\n"

# Write the accumulated gnuplot script. The handle gets its own name so that
# the plot() function is not rebound at module level.
with open(plot_name,"wt") as plot_file:
    plot_file.write(cmds)

# Shell script that renders and post-processes the plot: gnuplot writes
# plots.pdf, pdfcrop writes the cropped version to the requested output file,
# and exiftool stores the content of desc.txt as its Description tag.
with open(make_name,"wt") as make:
    make.write("#!/bin/bash\n" +
               ("gnuplot '%s'\n" % plot_name) +
               ("pdfcrop '%s' '%s' 1>/dev/null\n" % (fpdf_name,fout)) +
               ("exiftool '%s' '-Description<=%s' -overwrite_original 1>/dev/null\n" % (fout,desc_name)))
os.chmod(make_name,stat.S_IRWXU|stat.S_IRWXG)

# Run the script. stderr is merged into stdout, so anything the tools print
# (stdout of pdfcrop/exiftool is discarded by the script) is reported.
gp = subprocess.Popen([make_name], stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.STDOUT)
# The pipes are binary, so the (empty) input has to be bytes.
ret=gp.communicate(input=b"")
if ret[0] != b"":
    print("Error: %s" % ret[0].decode("utf-8"))

# Remove the intermediate files unless -k was given.
if not keep:
    for fn in [ desc_name, fpdf_name, plot_name, make_name ] + files:
        try:
            os.unlink(fn)
        except FileNotFoundError:
            # E.g. plots.pdf is missing when gnuplot failed; the failure was
            # already reported above, so cleanup just continues.
            pass
    os.rmdir(targetDir)

