#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import jks, os, sys, glob, numpy, math

# Six positional arguments are mandatory: out in ops fmtC fmtO dt
if len(sys.argv) < 7:
    print("%s out in ops fmtC fmtO dt" % sys.argv[0])
    sys.exit(0)

fout, fin = sys.argv[1], sys.argv[2]
ops = sys.argv[3].split(",")  # comma-separated list of operator names
fmt, fmtO = sys.argv[4], sys.argv[5]  # %-format strings for input / output keys
dt = int(sys.argv[6])  # offset between t and the reference time t0 = t + dt

# Optional 7th argument: scale factor handed to the resampling apply() call.
jkscale = 1.0
if len(sys.argv) > 7:
    jkscale = float(sys.argv[7])
    print("Use a jkscale factor of %g" % jkscale)

# The mere presence of an 8th argument (its value is ignored) relaxes the
# sign requirement on the GEVP eigenvalues.
flagLenientSignLambda = len(sys.argv) > 8
if flagLenientSignLambda:
    print("Lenient in the sign of lambda")


# GEVP: 
#   C(t) u_n = lambda_n(t-t0) C(t0) u_n
#   inv(C(t0)) C(t) u_n = lambda_n(t-t0) u_n  -> find eigenvalues and eigenvectors of inv(C(t0)) C(t)

# V_ni = <n| O_i |0>

# C_ij(t) = sum_n <0| Oi |n>  <n| Oj |0> e^{-E_n t}
#         = sum_n dag(V)_in e^{-E_n t} V_nj 

# C(t) u_m = dag(V)_im e^{-E_m t} (Vu_m)_m = dag(V)_im e^{-E_m t} b
# dag(u_m) C(t) u_m = dag(Vu_m)_m (Vu_m)_m e^{-E_m t} = |b|^2 e^{-E_m t}

# b = (Vu_m)_m

#
# ( C(t) u_m ) / sqrt( dag(u_m) C(t) u_m ) = dag(V)_im e^{-E_m t/2} (b/|b|)
#

def con(c):
    """Return a new list with the complex conjugate of every entry of c."""
    conjugated = []
    for value in c:
        conjugated.append(value.conjugate())
    return conjugated

def cmp(r,i):
    """Combine real parts r and imaginary parts i into a list of complex numbers.

    Both sequences must have the same length.
    """
    assert len(r) == len(i)
    return [ complex(re_part, im_part) for re_part, im_part in zip(r, i) ]

def cmp0(r):
    """Promote every entry of r to a complex number and return them as a list."""
    return [ complex(value) for value in r ]

def get(r,a,b):
    """Return the correlator of operators a and b from r as a list of complex numbers.

    Keys are built from the global format string fmt.  The lookup order is:
      1. fmt % (a,b) stored under a single key,
      2. fmt % (b,a) stored under a single key (used without conjugation),
      3. fmt % (a,b) stored as a ".r" / ".i" pair,
      4. fmt % (b,a) stored as a ".r" / ".i" pair, returned complex-conjugated.
    """
    direct = fmt % (a, b)
    swapped = fmt % (b, a)

    # Data stored under one key without a separate imaginary part.
    if direct in r:
        return cmp0(r[direct])
    if swapped in r:
        return cmp0(r[swapped])

    # Data stored as separate real and imaginary parts.
    if direct + ".r" in r:
        return cmp(r[direct + ".r"], r[direct + ".i"])

    # Only the operator-swapped correlator is left; conjugate it.
    return con(cmp(r[swapped + ".r"], r[swapped + ".i"]))

def gevp(r):
    """Solve the GEVP for one resample and return all results as one flat list.

    For every admissible time t, paired with the reference time t0 = t + dt,
    the matrix inv(C(t0)) C(t) is diagonalized, where C_ij(t) is built from
    the correlators of ops[i] and ops[j] obtained through get().

    r is passed straight to get(); it only has to support the key lookups
    performed there.

    The returned list concatenates, for each t, one chunk of
    N + N*N + 2*N*N numbers (N = len(ops)):
      - N energies En = -log(lambda) / (t - t0), sorted in ascending order,
      - N*N factors cn2[n][m] for state n and operator m (state-major),
      - N eigenvectors in the same state order, each stored as N interleaved
        (real, imag) pairs.
    A chunk keeps its NaN placeholders when one of the two correlator
    matrices has a NaN norm or when the eigenvalues are not real and positive
    (positivity is waived if flagLenientSignLambda is set).  If the lowest
    energy is not positive, only the energies of that chunk are filled in.
    """
    N=len(ops)

    # Fetch every correlator exactly once: get() rebuilds the complete time
    # series on each call, so it must not be invoked per time slice.
    corr=[ [ get(r,ops[i],ops[j]) for j in range(N) ] for i in range(N) ]
    T=len(corr[0][0])

    # C[t] is the N x N correlator matrix at time t.
    C={ t: numpy.array( [ [ corr[i][j][t] for j in range(N) ] for i in range(N) ] ) for t in range(T) }
    nrm=[ numpy.linalg.norm(C[t]) for t in range(T) ]

    # Restrict t such that the partner time t0 = t + dt stays inside [0,T).
    if dt > 0:
        rr=range(0,T-dt)
    else:
        rr=range(-dt,T)

    res=[ ]
    for t in rr:
        t0=t+dt

        # NaN placeholders, overwritten only if the solve below succeeds.
        En=[ float("nan") for n in range(N) ]
        cn2=[ [ float("nan") for n in range(N) ] for m in range(N) ]
        uni=[ [ float("nan") for n in range(2*N) ] for m in range(N) ]

        if not math.isnan(nrm[t]) and not math.isnan(nrm[t0]):
            C0inv=numpy.linalg.inv(C[t0])
            M=numpy.dot(C0inv,  C[t])
            res0,res1=numpy.linalg.eig(M)

            if flagLenientSignLambda:
                # Lenient mode drops the imaginary parts of the eigenvalues.
                res0=numpy.array(res0.real,dtype=numpy.complex128)

            all_positive=not numpy.any(res0.real <= 0.0)
            all_real=not numpy.any(abs(res0.imag) >= 1e-10)

            if isinstance(res0[0],numpy.complex128) and (all_positive or flagLenientSignLambda) and all_real:
                En=(-numpy.log(res0.real) / (t-t0)).tolist()
                # Order the states by increasing energy; idx maps the sorted
                # state index back to the eigenvector column.
                idx=numpy.argsort(En).tolist()
                En=[ En[i] for i in idx ]
                if En[0] > 0.0:
                    for n in range(N):
                        u_n=res1[:,idx[n]]

                        # Store the eigenvector as interleaved (real, imag) pairs.
                        uni[n]=[]
                        for x in u_n:
                            uni[n].append(x.real)
                            uni[n].append(x.imag)

                        for m in range(N):
                            cn2[n][m]= 1.0

                        # Geometric mean over the time slices listed in the
                        # tp loop (currently only tp = t).
                        Navg=0.
                        for tp in [t]:
                            Cu_n=numpy.dot(  C[tp], u_n )
                            uCu=numpy.vdot(u_n, Cu_n)
                            for m in range(N):
                                # |(C u_n)_m|^2 / |dag(u_n) C u_n| with the
                                # factor e^{-E_n tp} divided out, following
                                # the derivation in the header comments.
                                cn2[n][m] *= abs(Cu_n[m])**2.0 / abs(uCu) * math.exp( En[n] * tp )
                            Navg+=1.

                        for m in range(N):
                            cn2[n][m]= math.pow( cn2[n][m], 1.0 / Navg )

        # Append this time slice's chunk: energies, overlaps, eigenvectors.
        res.extend(En)
        for n in range(N):
            res.extend(cn2[n])
        for n in range(N):
            res.extend(uni[n])

    return res
       
# Load the resamples from the input file.
jks2 = jks.resamples(fin)

# STATS_KEEP_FIXED may hold a comma-separated list of patterns; each one is
# expanded with jks2.match() and the de-duplicated result is passed to apply().
env_keep_fixed = os.getenv("STATS_KEEP_FIXED")
if env_keep_fixed is not None:
    print("Warning: keeping %s fixed" % env_keep_fixed)
    env_keep_fixed = list({
        name
        for pattern in env_keep_fixed.split(",")
        for name in jks2.match(pattern)
    })
    print("Translated: " + str(env_keep_fixed))
jk = jks2.apply(gevp, scale=jkscale, keep_fixed=env_keep_fixed)

# Size of the chunk gevp() emits per time slice:
# N energies + N*N overlap factors + N eigenvectors of N (real, imag) pairs.
N = len(ops)
block = N + N*N + 2*N*N

def part(mm,t):
    """Unpack the chunk of time slice t from the flat gevp() output mm.

    Returns (En, cn2, unir, unii) where
      - En is the slice of N energies,
      - cn2[i][j] is the overlap factor for operator i and state j,
      - unir[n][j] / unii[n][j] are the real / imaginary part of component j
        of the eigenvector of state n.
    """
    chunk = mm[t*block:(t+1)*block]
    overlap_start = N
    vector_start = N + N*N

    En = chunk[0:overlap_start]

    # The overlaps are stored state-major; transpose to [operator][state].
    flat_overlaps = chunk[overlap_start:vector_start]
    cn2 = [ [ flat_overlaps[op + N*state] for state in range(N) ] for op in range(N) ]

    # Eigenvector components are stored as interleaved (real, imag) pairs.
    flat_vectors = chunk[vector_start:]
    unir = [ [ flat_vectors[2*(N*n + j)] for j in range(N) ] for n in range(N) ]
    unii = [ [ flat_vectors[2*(N*n + j) + 1] for j in range(N) ] for n in range(N) ]
    return En, cn2, unir, unii

def get_En(mm,m): # m is the state
    """Return the energy of state m for every time slice stored in mm."""
    series = []
    for t in range(len(mm) // block):
        energies = part(mm, t)[0]
        series.append(energies[m])
    return series

def get_cn2(mm,n,m): # n is the operator, m is the state
    """Return the overlap factor of operator n with state m for every time slice in mm."""
    series = []
    for t in range(len(mm) // block):
        overlaps = part(mm, t)[1]
        series.append(overlaps[n][m])
    return series

def get_cni_r(mm,n,i): # n is the state, i is the operator
    """Return the real part of eigenvector component i of state n for every time slice in mm."""
    series = []
    for t in range(len(mm) // block):
        real_parts = part(mm, t)[2]
        series.append(real_parts[n][i])
    return series

def get_cni_i(mm,n,i): # n is the state, i is the operator
    """Return the imaginary part of eigenvector component i of state n for every time slice in mm."""
    series = []
    for t in range(len(mm) // block):
        imag_parts = part(mm, t)[3]
        series.append(imag_parts[n][i])
    return series

# Store the flat GEVP output under the single key "full" so that the
# individual time series can be pulled out of it with apply().
tmp = jks.resamples()
tmp.add("full", jk)

res = jks.resamples()
for m in range(N):
    # switch n and m meaning for files for historical consistency
    jk = tmp.apply(lambda r: get_En(r["full"], m))
    res.add(fmtO % ("En-" + str(m)), jk)

    for n in range(N):
        tag = str(n) + "-" + str(m)

        # n<>m meaning, see above
        jk = tmp.apply(lambda r: get_cn2(r["full"], n, m))
        res.add(fmtO % ("c2mn-" + tag), jk)

        # Eigenvector components, real and imaginary part in separate entries.
        jk = tmp.apply(lambda r: get_cni_r(r["full"], n, m))
        res.add(fmtO % ("cni-" + tag + ".r"), jk)

        jk = tmp.apply(lambda r: get_cni_i(r["full"], n, m))
        res.add(fmtO % ("cni-" + tag + ".i"), jk)

#res.compress()
res.save(fout)
