#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import jks, os, sys, glob, numpy, math

# Command-line handling: nine mandatory positional arguments, optionally
# followed by the two comma-separated lists opsLeft and opsRight.
if (len(sys.argv) != 10 and len(sys.argv) != 12):
    print("%s inout corrsin ops keepstates tstar fmtCI fmtCO fmtc frozenCoef [ opsLeft opsRight ]" % sys.argv[0])
    sys.exit(0)

# Resample file that is loaded, extended with the rotated correlators and
# written back at the end of the script.
finout=sys.argv[1]
# Resample file holding the input correlators.
fcorrsin=sys.argv[2]
# Comma-separated operator names spanning the correlator matrix.
ops=sys.argv[3].split(",")
# Comma-separated state list; below only its length (Nkeep) is used.
keepstates=sys.argv[4].split(",")
# Index at which the coefficient series "cni-<state>-<op>" are evaluated.
tstar=int(sys.argv[5])
# Format strings: fmtCI % (opA, opB) names an input correlator,
# fmtCO % (a, b) names an output correlator and fmtc % "cni-<n>-<i>" names
# a coefficient series.
fmtCI=sys.argv[6]
fmtCO=sys.argv[7]
fmtc=sys.argv[8]
# NOTE(review): eval() executes arbitrary Python taken from the command
# line.  Kept so that every expression accepted so far stays valid; only
# the truth value of the result is used.  Consider ast.literal_eval if the
# argument is always a plain literal such as True/False/0/1.
frozenCoef=eval(sys.argv[9])
opsLeft=[]
opsRight=[]
if len(sys.argv) == 12:
    # Build real lists, not filter objects: under Python 3 filter() returns
    # a one-shot iterator, and these names are iterated once per kept state,
    # so every state after the first would silently see an empty sequence.
    opsLeft=[ x for x in sys.argv[10].split(",") if x != "" ]
    opsRight=[ x for x in sys.argv[11].split(",") if x != "" ]

N=len(ops)
Nkeep=len(keepstates)

def con(c):
    """Return a new list holding the complex conjugate of each element of c."""
    conjugated = []
    for value in c:
        conjugated.append(value.conjugate())
    return conjugated

def get_real(c):
    """Return a new list with the real part (.real) of each element of c."""
    real_parts = []
    for value in c:
        real_parts.append(value.real)
    return real_parts

def get_imag(c):
    """Return a new list with the imaginary part (.imag) of each element of c."""
    imag_parts = []
    for value in c:
        imag_parts.append(value.imag)
    return imag_parts

def cmp(r,i):
    """Combine real parts r and imaginary parts i into a list of complex numbers.

    Both sequences must have the same length; element k of the result is
    complex(r[k], i[k]).
    """
    assert(len(r) == len(i))
    return [ complex(real_part, imag_part) for real_part, imag_part in zip(r, i) ]

def cmp0(r):
    """Convert each element of r to complex via complex(x) and return the list."""
    as_complex = []
    for value in r:
        as_complex.append(complex(value))
    return as_complex

def get(r,a,b):
    """Look up the correlator of operators a and b in r as a complex list.

    The key is built from the global format fmtCI.  Lookup order:
      1. fmtCI % (a,b) stored as a single series  -> converted with cmp0
      2. fmtCI % (b,a) stored as a single series  -> converted with cmp0
      3. fmtCI % (a,b) stored as ".r"/".i" series -> combined with cmp
      4. fmtCI % (b,a) stored as ".r"/".i" series -> combined and conjugated
    If none of the first three is present, the lookup for case 4 is attempted
    unconditionally, so a missing entry surfaces as whatever error r raises
    for an unknown key.
    """
    key_ab = fmtCI % (a,b)
    key_ba = fmtCI % (b,a)

    # Single-series storage: no conjugation is applied for the swapped key.
    if key_ab in r:
        return cmp0(r[key_ab])
    if key_ba in r:
        return cmp0(r[key_ba])

    # Split real/imaginary storage.
    if key_ab + ".r" in r:
        return cmp(r[key_ab + ".r"], r[key_ab + ".i"])

    # Only the swapped pair is left; conjugate to get the (a,b) ordering.
    swapped = cmp(r[key_ba + ".r"], r[key_ba + ".i"])
    return con(swapped)

def rotate(r,n,m):
    """Contract the operator-basis correlator matrix with the coefficient
    vectors of states n and m.

    For every time index t this evaluates  C_nm(t) = u_n^*_i C_ij(t) u_m^j,
    where C_ij is get(r, ops[i], ops[j]) and u_k^i is the coefficient series
    fmtc % "cni-<k>-<i>" (".r"/".i" parts) evaluated at index tstar.

    Parameters:
      r -- mapping-like sample providing the correlator and coefficient series
      n -- index of the state on the left (conjugated) side
      m -- index of the state on the right side

    Returns a list with one complex value per time index; its length is that
    of the ops[0]/ops[0] correlator.
    """
    N=len(ops)

    # Fetch each correlator exactly once.  get() rebuilds the full complex
    # series on every call, so calling it inside the time loop costs
    # O(T^2 N^2) instead of O(T N^2).
    corr=[ [ get(r,ops[i],ops[j]) for j in range(N) ] for i in range(N) ]
    T=len(corr[0][0])

    un=[]
    um=[]
    for i in range(N):
        f=fmtc % ("cni-%d-%d" % (n,i))
        un.append(cmp(r[f+".r"],r[f+".i"])[tstar])
        f=fmtc % ("cni-%d-%d" % (m,i))
        um.append(cmp(r[f+".r"],r[f+".i"])[tstar])
    # The conjugated left vector does not depend on t; compute it once.
    un_conj=numpy.conj(numpy.array(un))
    um=numpy.array(um)

    # C_nm = u_n^*_i C_ij u_m^j
    res=[]
    for t in range(T):
        Ct=numpy.array([ [ series[t] for series in row ] for row in corr ])
        res.append(numpy.dot(numpy.dot(un_conj, Ct),um))
    return res

def rotate_opleft(r,op,m):
    """Contract only the right index of the correlators get(r, op, ops[j])
    with the coefficient vector of state m.

    For every time index t this evaluates  C_op,m(t) = C_op,j(t) u_m^j,
    where u_m^j is the coefficient series fmtc % "cni-<m>-<j>" (".r"/".i"
    parts) evaluated at index tstar.

    Parameters:
      r  -- mapping-like sample providing the correlator and coefficient series
      op -- name of the operator kept unrotated on the left
      m  -- index of the state on the right side

    Returns a list with one complex value per time index; its length is that
    of the ops[0]/ops[0] correlator.
    """
    T=len(get(r,ops[0],ops[0]))
    N=len(ops)

    # Fetch each correlator exactly once; get() rebuilds the full series on
    # every call, so it must not sit inside the time loop.
    corr=[ get(r,op,ops[j]) for j in range(N) ]

    um=[]
    for i in range(N):
        f=fmtc % ("cni-%d-%d" % (m,i))
        um.append(cmp(r[f+".r"],r[f+".i"])[tstar])
    um=numpy.array(um)

    # C_opm = C_op,j u_m^j
    res=[]
    for t in range(T):
        Ct=numpy.array([ series[t] for series in corr ])
        res.append(numpy.dot(Ct,um))
    return res

def rotate_opright(r,n,op):
    """Contract only the left index of the correlators get(r, ops[i], op)
    with the conjugated coefficient vector of state n.

    For every time index t this evaluates  C_n,op(t) = u_n^*_i C_i,op(t),
    where u_n^i is the coefficient series fmtc % "cni-<n>-<i>" (".r"/".i"
    parts) evaluated at index tstar.

    Parameters:
      r  -- mapping-like sample providing the correlator and coefficient series
      n  -- index of the state on the left (conjugated) side
      op -- name of the operator kept unrotated on the right

    Returns a list with one complex value per time index; its length is that
    of the ops[0]/ops[0] correlator.
    """
    T=len(get(r,ops[0],ops[0]))
    N=len(ops)

    # Fetch each correlator exactly once; get() rebuilds the full series on
    # every call, so it must not sit inside the time loop.
    corr=[ get(r,ops[i],op) for i in range(N) ]

    un=[]
    for i in range(N):
        f=fmtc % ("cni-%d-%d" % (n,i))
        un.append(cmp(r[f+".r"],r[f+".i"])[tstar])
    # The conjugated vector does not depend on t; compute it once.
    un_conj=numpy.conj(numpy.array(un))

    # C_n,op = u_n^*_i C_i,op
    res=[]
    for t in range(T):
        Ct=numpy.array([ series[t] for series in corr ])
        res.append(numpy.dot(un_conj, Ct))
    return res

# Driver: load both resample sets, compute the rotated correlators for every
# pair of kept states (plus the optional one-sided rotations) and store the
# real and imaginary parts as separate entries in finout.
res=jks.resamples(finout)
jks2=jks.resamples(fcorrsin)
# NOTE(review): presumably makes the entries of res (e.g. the coefficient
# series read by rotate*) available in jks2 -- confirm against the jks API.
jks2.take(res,"*")
# States are addressed by position 0..Nkeep-1; the entries of keepstates
# themselves are not used in the code visible here.
for n in range(Nkeep):
    # kf depends on n only and is also needed by the left/right loops below,
    # so it is built once per n and not inside the m loop.
    # NOTE(review): only the coefficients of state n are listed, although
    # rotate(r,n,m) also reads those of state m -- confirm this is intended.
    if frozenCoef:
        kf=[ fmtc % ("cni-%d-%d.r" % (n,i)) for i in range(N) ] + [ fmtc % ("cni-%d-%d.i" % (n,i)) for i in range(N) ]
    else:
        kf=[]

    for m in range(Nkeep):
        key=fmtCO % ("%d" % n,"%d" % m)
        # in future make more efficient
        jk=jks2.apply(lambda r: get_real(rotate(r,n,m)),keep_fixed = kf)
        res.add(key + ".r",jk)
        jk=jks2.apply(lambda r: get_imag(rotate(r,n,m)),keep_fixed = kf)
        res.add(key + ".i",jk)

    # Now add left
    for op in opsLeft:
        key=fmtCO % (op,"%d" % n)
        jk=jks2.apply(lambda r: get_real(rotate_opleft(r,op,n)),keep_fixed = kf)
        res.add(key + ".r",jk)
        jk=jks2.apply(lambda r: get_imag(rotate_opleft(r,op,n)),keep_fixed = kf)
        res.add(key + ".i",jk)

    # ... and the operators kept unrotated on the right
    for op in opsRight:
        key=fmtCO % ("%d" % n,op)
        jk=jks2.apply(lambda r: get_real(rotate_opright(r,n,op)),keep_fixed = kf)
        res.add(key + ".r",jk)
        jk=jks2.apply(lambda r: get_imag(rotate_opright(r,n,op)),keep_fixed = kf)
        res.add(key + ".i",jk)

#res.compress()
res.save(finout)
