#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import jks, sys, os, math, glob, numpy as np
# Command line: database file followed by one or more (tag, op) pairs, i.e. an
# even total number of argv entries with at least four of them.
if not (len(sys.argv) >= 4 and len(sys.argv) % 2 == 0):
    usage_lines = (
        "%s database.jks tag1 op1 [tag2 op2 ...]" % sys.argv[0],
        "",
        "Info:",
        "- the input dictionary is 'r'",
        "- numpy is defined as np",
        "- Available functions: fold, partial_sum, rpartial_sum, cshift, shift, foldsum, emp_log, emp_cosh, interpolate",
    )
    for usage_line in usage_lines:
        print(usage_line)

    sys.exit(0)

# Number of (tag, op) pairs given on the command line.
Nop = (len(sys.argv) - 2) // 2

def cmerge(r,corrs):
    """Concatenate the selected data of several entries of r into one array.

    Each element c of corrs selects np.array(r[c[0]])[c[1]], i.e. c[0] is a key
    of r and c[1] an index expression.  The pieces are joined with np.append in
    the order given by corrs.

    Raises TypeError if corrs is empty (nothing to reduce).
    """
    # reduce is only a builtin in Python 2; importing it from functools works
    # on both major versions.
    from functools import reduce

    pieces = [ np.array(r[c[0]])[c[1]] for c in corrs ]
    return reduce(np.append, pieces)

def getkey(corrs):
    """Return a string built from the first two fields of every entry of corrs.

    The string is the str() of the list of (entry[0], entry[1]) tuples.
    """
    pairs = []
    for entry in corrs:
        pairs.append((entry[0], entry[1]))
    return str(pairs)

def first_available(a, *keys):
    """Return a[k] for the first k of keys that is contained in a.

    Raises Exception if none of the keys is contained in a.
    """
    not_found = object()
    hit = next((a[k] for k in keys if k in a), not_found)
    if hit is not_found:
        raise Exception("%s not found in array" % str(keys))
    return hit

# Covariance matrices computed by gfit, cached per fit specification so that
# each one is built only once.
cov={}
# Parameters found by the first gfit call for each fit specification; later
# gfit calls for the same specification start from them instead of from the
# guess supplied by the caller.
lastguess={}
def cfit(r,lfits,guess):
    """Correlated fit: forward to gfit with corr=True and return its result."""
    result = gfit(r, lfits, guess, corr=True)
    return result

def fit(r,lfits,guess):
    """Uncorrelated fit: forward to gfit with corr=False and return its result."""
    result = gfit(r, lfits, guess, corr=False)
    return result

def gfitf(x,p,corrs):
    """Evaluate the model belonging to point x of the merged data vector.

    The merged vector is the concatenation of the selections c[1] of all
    entries c of corrs.  x is an index into that vector; it is mapped back to
    the entry it falls into and to the matching element xl of c[1], and
    c[2](xl, p) is returned.

    Raises AssertionError if x lies outside the merged range.
    """
    start = 0
    for c in corrs:
        stop = start + len(c[1])
        if start <= x < stop:
            xl = list(c[1])[x - start]
            return c[2](xl, p)
        start = stop

    # Raised explicitly instead of using an assert statement, which would be
    # removed under "python -O" and let the function silently return None.
    raise AssertionError("gfitf: point index %s is outside of the merged fit range" % str(x))

def gfit(r,lfits,guess,corr):
    """Fit the merged data selected by lfits and return parameters and fit quality.

    Parameters:
      r     -- input dictionary of the current sample
      lfits -- list of entries c; the data np.array(r[c[0]])[c[1]] of all
               entries is merged (see cmerge) and point x of the merged vector
               is modelled by gfitf(x, p, lfits)
      guess -- initial parameters; only used by the first call for a given
               lfits, later calls start from the result cached in lastguess
      corr  -- True: use the full covariance matrix of the merged data,
               False: keep only its diagonal

    Returns fr["val"] extended by [chi2, dof, p] for corr=True and by three
    NaNs for corr=False.  If the fit reports failure, a list of
    len(guess) + 3 NaNs is returned instead.
    (fr["val"] is presumably a plain list -- confirm against jks.qfit.)
    """
    global res
    mn=cmerge(r,lfits)
    key=getkey(lfits)

    # The cached matrix differs between correlated and uncorrelated fits, so
    # corr has to be part of the cache key.  Keying on lfits alone made fit()
    # and cfit() of the same lfits share whichever matrix was created first.
    cov_key = key if corr else "diag:" + key
    if cov_key in cov:
        cv=cov[cov_key]
    else:
        print("Creating cov %s" % key)
        jk=res.apply(lambda ri: cmerge(ri,lfits).tolist())
        cv=jk.cov()
        if not corr:
            # uncorrelated fit: zero all off-diagonal elements
            n=len(cv)
            cv=[ [ cv[i][i] if i==j else 0.0 for i in range(n) ] for j in range(n) ]
        cov[cov_key]=cv

    def model(x,p):
        return gfitf(x,p,lfits)

    def run(start,tolerance):
        # one simplex fit of the merged data, starting from the parameters start
        return jks.qfit(mn,cv,range(len(mn)),model,start,verbose=False,maxiter=1000,tolerance=tolerance,simplex=True)

    if key in lastguess:
        # restart from the parameters found by the first call for this lfits
        guess=lastguess[key]
        fr=run(guess,1e-8)
        # TODO: implement non-simplex newton-like stable method for this case
        print(fr)
    else:
        # first call: coarse fit from the caller's guess, then a refined fit
        # whose result becomes the starting point of all later calls
        fr=run(guess,1e-4)
        print(fr)
        fr=run(fr["val"],1e-8)
        lastguess[key]=fr["val"]
        print("Init")
        print(fr)

    # explicit comparison kept on purpose: the type of fr["success"] is
    # defined by jks.qfit and not visible here
    if fr["success"] == False:
        return [ float("nan") ] * (len(guess) + 3)

    if corr:
        quality = [ fr["chi2"], fr["dof"], fr["p"] ]
    else:
        # no fit-quality numbers are reported for uncorrelated fits
        quality = [ float("nan") ] * 3
    return fr["val"] + quality

def time_reverse(C):
    """Return C with index t mapped to (T - t) % T, where T = len(C).

    Element 0 stays in place, the remaining elements appear in reverse order.
    """
    extent = len(C)
    mirrored = []
    for t in range(extent):
        mirrored.append(C[(extent - t) % extent])
    return np.array(mirrored)

def apb(C, t0):
    """Flip the sign of the last t0 elements of C (along the first axis).

    For t0 > 0 a new array is returned; otherwise C itself is returned
    unchanged.
    """
    if not t0 > 0:
        return C
    kept = C[0:-t0]
    negated = C[-t0:]*(-1)
    return np.concatenate((kept, negated), axis=0)
        
def interpolate(y,y_val,x,order):
    """Find the position x at which y first crosses the value y_val.

    y and x are sequences of equal length; the first pair of neighbouring
    points (i-1, i) between which y passes y_val is used.

    order == 1: linear interpolation between the two points.
    order == 2: parabola through the three points centred on the point closest
                to the crossing; its root is approached by two Newton steps
                starting from the linear estimate.  The conversion back to x
                uses the spacing x[i0] - x[i0-1], i.e. it assumes an equally
                spaced x.

    Returns a one-element numpy array.  It contains NaN if there is no
    crossing, if order is neither 1 nor 2, or if order == 2 is requested for
    fewer than three points.
    """
    n = len(y)
    sign_prev = 0
    for i in range(n):
        sign = 1 if y[i] > y_val else -1
        if sign_prev * sign < 0:
            # we have crossed the value between i-1 and i
            if order == 1:
                # y[i-1] + lam*(y[i]-y[i-1]) == y_val
                # ->
                # lam = (y_val - y[i-1]) / (y[i] - y[i-1])
                lam = (y_val - y[i-1]) / (y[i] - y[i-1])
                return np.array([x[i-1] + lam*(x[i]-x[i-1])])
            elif order == 2:
                if n < 3:
                    # a parabola needs three points
                    return np.array([float("nan")])

                # solution is between i-1 and i, find out which is closer;
                # lam is measured relative to the stencil centre i0
                lam = (y_val - y[i-1]) / (y[i] - y[i-1])
                if lam > 0.5:
                    i0 = i
                    lam -= 1.0
                else:
                    i0 = i-1

                # Keep the three-point stencil inside the array.  Without this
                # a crossing next to the first point silently used y[-1] and
                # one next to the last point raised IndexError.
                if i0 < 1:
                    i0 += 1
                    lam -= 1.0
                elif i0 > n - 2:
                    i0 -= 1
                    lam += 1.0

                # Q(l) = c0 + c1*l + c2*l**2 passes through the three stencil
                # points at l = -1, 0, 1 (shifted by y_val).  This requires
                # c2 = (y[i0+1] + y[i0-1] - 2*y[i0])/2; the previous divisor 4
                # gave a curve that missed both neighbours.
                c1 = (y[i0 + 1] - y[i0 - 1])/2.
                c2 = (y[i0 + 1] + y[i0 - 1] - 2*y[i0])/2.
                c0 = y[i0] - y_val

                def Q(l):
                    return c0 + c1*l + c2*l**2.

                def Qp(l):
                    return c1 + 2.0*c2*l

                # two steps of Newton
                for _ in range(2):
                    lam -= Q(lam)/Qp(lam)

                return np.array([x[i0] + lam*(x[i0] - x[i0-1])])
            else:
                return np.array([float("nan")])

        sign_prev = sign
    return np.array([float("nan")])
    
def emp_log_f(x,xp):
    """Return log(x/xp), or NaN if the ratio x/xp is not positive."""
    ratio = x/xp
    return float("nan") if ratio <= 0 else math.log(ratio)

def emp_log(c,ndisp=1):
    """Return the list of emp_log_f(c[i], c[i+ndisp]) / ndisp.

    i runs over range(len(c) - ndisp), so the result is ndisp elements shorter
    than c.
    """
    values = []
    for i in range(len(c)-ndisp):
        values.append(emp_log_f(c[i],c[i+ndisp])/ndisp)
    return values

def emp_cosh_f(x,xp,xpp):
    """Return acosh((x + xpp) / (2 xp)), or NaN if that ratio is not above 1."""
    ratio = (x+xpp)/xp/2.0
    return float("nan") if ratio <= 1 else math.acosh(ratio)

def emp_cosh(c):
    """Return the list of emp_cosh_f over all triples of consecutive elements of c.

    The result is two elements shorter than c.
    """
    values = []
    for i in range(len(c)-2):
        values.append(emp_cosh_f(c[i],c[i+1],c[i+2]))
    return values

def fold(c):
    """Return the average of c[i] and c[-i] for i = 0 .. len(c)//2 as an array."""
    last = len(c)//2
    averaged = []
    for i in range(last+1):
        averaged.append((c[i] + c[-i])*0.5)
    return np.array(averaged)

def afold(c):
    """Return half the difference c[i] - c[-i] for i = 0 .. len(c)//2 as an array."""
    last = len(c)//2
    halved = []
    for i in range(last+1):
        halved.append((c[i] - c[-i])*0.5)
    return np.array(halved)

def unfold(c):
    """Extend the array c by its interior elements in reverse order.

    The appended part is c[len(c)-2], ..., c[1]; the first and the last element
    of c are not repeated.
    """
    mirrored = []
    for i in range(len(c)-2, 0, -1):
        mirrored.append(c[i])
    return np.array(c.tolist() + mirrored)

def cshift(a,i):
    """Return a new array holding a[i:] followed by a[0:i] (cyclic shift by i)."""
    leading = a[0:i].tolist()
    trailing = a[i:].tolist()
    return np.array(trailing + leading)

def shift(a,i):
    """Shift the array a by i positions, filling the vacated positions with NaN.

    i > 0: result[t] = a[t+i], the last i elements are NaN.
    i < 0: result[t] = a[t+i], the first -i elements are NaN.
    i == 0: an unshifted copy of a.

    A new array is returned; a is not modified.
    """
    if i == 0:
        # Must be handled separately: a[0:] is the whole array and a[0:0] is
        # empty, so the i <= 0 branch below would return only NaNs.
        return np.array(a.tolist())
    if i > 0:
        r=np.array(a[i:].tolist() + (float("nan")*a[0:i]).tolist())
    else:
        r=np.array((float("nan")*a[i:]).tolist() + a[0:i].tolist())
    return r

# foldsum: a[t] + a[-t] for t != 0 (twice the fold), while t == 0 is left untouched
def foldsum(a):
    """Return [a[0], a[1]+a[-1], a[2]+a[-2], ...] with len(a)//2 + 1 elements.

    a has to be a numpy array (element-wise addition and tolist are used).
    """
    half = len(a)//2
    rest = a[1:]
    mirrored = rest[::-1]
    summed = (rest + mirrored)[:half]
    return np.array([ a[0] ] + summed.tolist())

def partial_sum(c):
    """Return the array of running sums: element i is c[0] + ... + c[i].

    A single running total is used instead of re-summing every prefix.  The
    additions are carried out in the same order as sum(c[0:i+1]) would do
    them, so the values are unchanged while the cost drops from quadratic to
    linear in len(c).
    """
    sums = []
    total = 0
    for value in c:
        # build a new object each step so that stored entries are never aliased
        total = total + value
        sums.append(total)
    return np.array(sums)

def rpartial_sum(c):
    """Return the array of tail sums: element i is c[i+1] + ... + c[-1].

    The last element is the empty sum 0.
    """
    tails = []
    for first in range(1, len(c)+1):
        tails.append(sum(c[first:]))
    return np.array(tails)


def DT(t,T):
    """Map t into the interval [-T//2, T//2) by adding or subtracting multiples of T.

    Note that -T//2 is evaluated as (-T)//2, so for odd T the lower bound is
    -(T+1)/2.
    """
    lower = -T//2
    upper = T//2
    while t < lower:
        t = t + T
    while t >= upper:
        t = t - T
    return t

# Database file: first command-line argument.  It is loaded here and the
# result is saved back to the same path at the end of the script.
fno=sys.argv[1]

# Module-level on purpose: gfit reads this object through "global res".
res=jks.resamples(fno)

# Process the (tag, op) pairs from the command line in the order given.
for i in range(Nop):
    otag=sys.argv[2+2*i]
    oop=sys.argv[3+2*i]
    print("Adding %d/%d: %s = %s" % (i+1,Nop,otag,oop))
    # The expression text becomes the body of a function of the input
    # dictionary r.  eval runs in this module's global namespace, which is what
    # makes np and the helper functions defined above usable in the expression.
    # NOTE(review): this executes arbitrary code from the command line; only
    # run the script with expressions you trust.
    op=eval("lambda r: %s" % oop)
    # presumably evaluates op on every resample -- confirm against jks
    jk=res.apply(op)
    # register the result under the requested tag; since this happens inside
    # the loop, later expressions are presumably able to use earlier tags -- confirm
    res.add(otag,jk)

# Optional compression step, currently disabled:
#if "JKS_NO_COMPRESS" not in os.environ:
#   res.compress()

res.save(fno)

