#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import jks, sys, os, math, glob, time
import numpy as np

# Expected command line: script name, input/output file, one or more
# (tag, range(s), fnc) triples, the initial guess and the fit tag, i.e.
# 7 + 3*k entries in sys.argv.  minargs counts the entries beyond the
# minimal seven.
minargs = len(sys.argv) - 7
if not (minargs >= 0 and minargs % 3 == 0):
    print("%s inout tag1 range(s)1 fnc1 [ tag2 range(s)2 fnc2 ... ] guess fittag" % sys.argv[0])
    sys.exit(0)

# Number of (tag, range(s), fnc) triples given on the command line.
n = minargs // 3 + 1

# Default fit tolerance; the environment variable JKS_FIT_TOL overrides it.
default_fit_tol = float(os.environ.get("JKS_FIT_TOL", 1e-8))

# Input file (also used as output file further below).
fni = sys.argv[1]

# Every third argument starting at position 2 is a tag name ...
tag = sys.argv[2:2 + 3 * n:3]

# ... and every third argument starting at position 3 is its range
# specification.  NOTE: the specification is evaluated as a Python
# expression taken from the command line.
rr = [ list(eval(spec)) for spec in sys.argv[3:3 + 3 * n:3] ]

# Verify and normalise the range(s): afterwards rr[i] is a list of index
# lists, one per fit range, for every tag i.
#  - A flat specification such as [3,4,5] or range(3,6) is one fit range.
#  - A nested specification such as [[3,4,5],[4,5,6]] or
#    [range(3,6),range(4,7)] is a list of several fit ranges.
# range objects and tuples have to be recognised explicitly: under Python 3
# range() no longer returns a list, so a pure list-type check would treat a
# list of ranges as one flat range and fail later when the offsets are added.
for i in range(n):
    if not rr[i]:
        raise ValueError("Empty range specification for tag %s" % tag[i])
    if isinstance(rr[i][0], (list, tuple, range)):
        rr[i] = [ list(r) for r in rr[i] ]
    else:
        rr[i] = [ rr[i] ]
    # Every tag has to provide the same number of fit ranges.  Raise
    # explicitly instead of using assert, which is stripped under python -O.
    if len(rr[i]) != len(rr[0]):
        raise ValueError("Tag %s has %d range(s) but tag %s has %d"
                         % (tag[i], len(rr[i]), tag[0], len(rr[0])))
nranges = len(rr[0])

# One fit function per tag.  The command line supplies the body of a lambda
# with the arguments x (index within the tag's data), p (parameter list) and
# r (dictionary of additional keys).  NOTE: the text is evaluated as Python
# code taken from the command line.
fnc = [ eval("lambda x,p,r: %s" % body) for body in sys.argv[4:4 + 3 * n:3] ]

# The last two arguments are the initial guess (a Python expression) and the
# tag under which the fit result is stored.
guess = eval(sys.argv[-2])
fittag = sys.argv[-1]

# guess may be extended later on; original_guess keeps the object that was
# given on the command line.
original_guess = guess

# The result is written back to the input file.
fno = fni

# TODO: change this, why first import?
jks2 = jks.resamples(fni)
res = jks.resamples()
res.take(jks2, "*")

# TODO: priors are not handled; an earlier, never completed stub looked for
# tuple entries in guess.


# prepare fit
# Offset of each tag's data inside the combined input vector; filled in by
# mkinp().
idx = {}
def mkinp(r):
    """Concatenate the data of all fit tags of r into one flat array.

    r is indexed by tag name.  As a side effect idx[name] is set to the
    position at which the data of that tag starts in the returned array.
    """
    combined = np.array([])
    for name in tag:
        idx[name] = len(combined)
        combined = np.append(combined, r[name])
    return combined

# Build the combined input vector for every sample and store it under a
# private key so that it can be fetched again for the per-sample fits.
jk = jks2.apply(mkinp)
jks2.add("#.#fit#.#", jk)

# Keep only the diagonal entries of jk.tcov(); every off-diagonal entry is
# replaced by 0.0.
v_cv = jk.tcov()
v_cv = [ [ (v_cv[row][row] if col == row else 0.0) for col in range(len(v_cv)) ]
         for row in range(len(v_cv)) ]

# TODO: create fnccomb, rrcomb
# rrcomb[i] is the i-th fit range expressed in indices of the combined input
# vector: the indices of every tag are shifted by that tag's offset.
rrcomb = [ [ idx[tag[j]] + x for j in range(n) for x in rr[j][i] ]
           for i in range(nranges) ]

# key -> (start, stop) slice of the parameter list that holds its values
param_keys = {}
# key -> fixed value handed to the fit functions unchanged
const_keys = {}
def fnccomb(x, p):
    """Evaluate the fit function responsible for combined index x.

    x is an index into the combined input vector and p the full parameter
    list.  The dictionary r passed to the fit function contains the slices
    of p registered in param_keys plus all entries of const_keys.  The tag
    with the largest offset not exceeding x is selected and its function is
    called with x relative to that offset.
    """
    r = { k: p[span[0]:span[1]] for k, span in param_keys.items() }
    r.update(const_keys)

    for i in range(n - 1, -1, -1):
        offset = idx[tag[i]]
        if offset <= x:
            return fnc[i](x - offset, p, r)

    # no tag covers x
    assert(0)

def scan():
    """Evaluate fnccomb once at every index of every combined fit range.

    The current global guess is used as parameter list and the results are
    discarded; any exception raised by a fit function propagates to the
    caller.
    """
    for fit_range in rrcomb:
        for point in fit_range:
            fnccomb(point, guess)

def update_keys(r, p):
    """Return parameter list p with all param_keys slots taken from r.

    For every key registered in param_keys the slice of p belonging to that
    key is replaced by r[key].tolist().  p itself is not modified; a new
    list is returned (p is returned as is when param_keys is empty).
    """
    for key, span in param_keys.items():
        head = p[:span[0]]
        tail = p[span[1]:]
        p = head + r[key].tolist() + tail
    return p
        
def need_key():
    """Register one additional key that the fit functions read from r.

    All fit functions are evaluated once via scan().  If that raises a
    KeyError, the missing key is looked up in jks2:
      * if np.linalg.norm of its tcov() is exactly zero, its mean is stored
        in const_keys as a constant,
      * otherwise its mean is appended to the global guess and the slice it
        occupies is recorded in param_keys.
    A key that is not present in jks2 terminates the program with exit
    status 1.

    Returns True if a key was registered (further keys may still be
    missing, so call again) and False once scan() completes without a
    KeyError.
    """
    global guess
    try:
        scan()
    except KeyError as e:
        if not e.args:
            raise
        # A failed dictionary lookup carries the missing key as the first
        # exception argument; read it directly rather than evaluating the
        # textual form of the exception.
        key = e.args[0]
    else:
        return False

    if key not in jks2.keys():
        print("Unknown key: %s" % (key,))
        sys.exit(1)

    entry = jks2.get(key)
    key_guess = entry.mean()
    key_cv = entry.tcov()
    if np.linalg.norm(key_cv) == 0.0:
        print("Constant %s = %s" % (key, str(key_guess)))
        const_keys[key] = key_guess
    else:
        param_keys[key] = (len(guess), len(guess) + len(key_guess))
        guess = guess + key_guess.tolist()
        print("Add %s as p[%s], guess = %s" % (key, str(param_keys[key]), str(key_guess)))
    return True

# Value handed to jks.qfit as estimate_hessians (alternative value: 1e-6).
eps_estimate_hessians = 1e-7
# Accuracy bound for the Hessian-based approximation; in this file it is
# only referenced from commented-out code.
eps_require_appx_precision = 2e-2

# Find out which r[tag] the fit functions need: need_key() registers one
# missing key per call and returns False once none is missing any more.
while need_key():
    pass

def qf(rr0, maxit, tol):
    """Fit the current global data vector v_mn on the index range rr0.

    jks.qfit is called with the global guess as starting point and
    freeze_index=len(original_guess).  If it reports no success and
    maxit > 10, a second attempt is started from the values of the first
    attempt followed by the entries of guess beyond len(original_guess).
    If the second attempt fails as well the function asserts.

    Returns the tuple (val, hessian_PP, hessian_PM) of the qfit result.
    """
    nfree = len(original_guess)
    # keyword arguments shared by both qfit calls
    options = dict(verbose=False, maxiter=maxit, tolerance=tol,
                   estimate_hessians=eps_estimate_hessians, freeze_index=nfree)
    # placeholder result, one NaN per parameter
    vfail = [ float("nan") for _ in guess ]

    out = jks.qfit(v_mn, v_cv, rr0, fnccomb, guess, **options)
    if not out["success"] and maxit > 10:
        print("First round did not converge", rr0, tol)
        improved_guess = out["val"] + guess[nfree:]
        print("Guess")
        print(guess)
        print(" --> ")
        print(improved_guess)
        out = jks.qfit(v_mn, v_cv, rr0, fnccomb, improved_guess, **options)
        if not out["success"]:
            print("Did not converge")
            assert(0)
            # only reached when assertions are disabled
            out = { "val" : vfail, "hessian_PP" : vfail, "hessian_PM" : vfail }
    return (out["val"], out["hessian_PP"], out["hessian_PM"])

def fit(c, maxit=10000, tol=default_fit_tol):
    """Fit the data vector c on every combined fit range.

    c is stored in the global v_mn, from which qf() reads the data.
    Returns a list with one (val, hessian_PP, hessian_PM) tuple per entry
    of rrcomb.
    """
    global v_mn
    v_mn = c
    return [ qf(fit_range, maxit, tol) for fit_range in rrcomb ]

def sfit(c, maxit=10000, tol=default_fit_tol):
    """Fit the data vector c on every combined fit range, values only.

    c is stored in the global v_mn, from which qf() reads the data.
    Returns one flat list holding the fitted values of all ranges in the
    order of rrcomb; the Hessians returned by qf() are dropped.
    """
    global v_mn
    v_mn = c

    values = []
    for fit_range in rrcomb:
        fitted = qf(fit_range, maxit, tol)[0]
        values = values + fitted
    return values

def sel(c, j, l):
    """Return a copy of c with every entry outside rr[j][l] replaced by NaN.

    j selects the tag and l the fit range; entry i of c is kept when i is
    contained in rr[j][l].
    """
    masked = []
    for i, value in enumerate(c):
        if i in rr[j][l]:
            masked.append(value)
        else:
            masked.append(float("nan"))
    return masked

def dfit(c, rs):
    """First-order estimate of the fit result for data c and key values rs.

    Instead of refitting, the central fit values d_fit[i][0] are shifted by
    delA[i] applied to the change of the data (c - d_mean, restricted to the
    indices of fit range i) plus delB[i] applied to the change of the
    parameter list caused by inserting rs via update_keys().

    Returns one flat list holding the estimates of all ranges in the order
    of rrcomb.
    """
    dfreeze = np.array(update_keys(rs, guess)) - np.array(guess)
    dm = np.array(c) - np.array(d_mean)

    estimate = []
    for i, fit_range in enumerate(rrcomb):
        central = np.array(d_fit[i][0])
        from_data = np.dot(delA[i], dm[fit_range])
        from_keys = np.dot(delB[i], dfreeze)
        estimate = estimate + (central + from_data + from_keys).tolist()
    return estimate

# Central fit: fit jk.mean() of the combined input on every fit range and
# measure the wall-clock time (t1-t0 is printed further below).  d_fit holds
# one (val, hessian_PP, hessian_PM) tuple per entry of rrcomb.
d_mean=jk.mean()
t0=time.time()
d_fit=fit(d_mean)
t1=time.time()

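# Linear response of the fitted parameters to small changes of the input.
# Notation: D[f,a,b] is the derivative of f with respect to a and b; m are
# the data and p the parameters (pv: varied in the fit, pfreeze: frozen);
# f is presumably the function minimised by the fit.
# At the minimum: D[f(m,p),p_i] == 0
# f(m+eps dm,p+eps dp)=f(m,p) + eps (D[f(m,p),m_i] dm_i + D[f(m,p),p_i] dp_i)
# D[f(m+eps dm,p+eps dp),p_j] =
#  eps (D[f(m,p),m_i,p_j] dm_i + D[f(m,p),p_i,p_j] dp_i)
# which is != 0 in general; requiring it to vanish again determines dp:
# P_{ij} = D[f(m,p),p_i,p_j]
# X_{ij} = D[f(m,p),p_i,m_j]
# X dm = - P dp = - Pv dpv - Pfreeze dpfreeze
# - Pvinv X dm - Pvinv Pfreeze dpfreeze = dpv
# delA dm + delB dpfreeze = dpv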
def minv(_A):
    """Invert a square matrix by Gauss-Jordan elimination.

    The input is copied and the whole elimination is carried out in
    np.longdouble, so integer input is accepted and the caller's matrix is
    never modified.  Rows are exchanged (partial pivoting) so that a zero
    or tiny diagonal entry of an invertible matrix does not cause a
    division by zero or a loss of accuracy.

    Parameters:
        _A: square matrix (nested sequence or 2-d array) of real numbers.

    Returns:
        The inverse as a 2-d array of dtype np.longdouble.

    Raises:
        np.linalg.LinAlgError: if the matrix is singular.
    """
    A = np.array(_A, dtype=np.longdouble)
    n = len(A)
    B = np.identity(n, dtype=np.longdouble)

    # Forward elimination: bring A to unit upper triangular form and apply
    # the same row operations to B.
    for i in range(n):
        # pivot = entry of largest magnitude in column i at or below row i
        piv = i + int(np.argmax(np.abs(A[i:, i])))
        if A[piv, i] == 0:
            raise np.linalg.LinAlgError("Singular matrix")
        if piv != i:
            A[[i, piv]] = A[[piv, i]]
            B[[i, piv]] = B[[piv, i]]

        lam = 1.0 / A[i, i]
        A[i] *= lam
        B[i] *= lam
        for j in range(i + 1, n):
            lam = A[j, i]
            A[j] -= lam * A[i]
            B[j] -= lam * B[i]

    # Back substitution: clear the entries above the diagonal.
    for i in reversed(range(n)):
        for j in range(i):
            lam = A[j, i]
            A[j] -= lam * A[i]
            B[j] -= lam * B[i]
    return B

# Per fit range: Pv is the block of hessian_PP belonging to the parameters
# of the original guess, Pfreeze the corresponding rows with all columns.
Pv = [ f[1][0:len(original_guess), 0:len(original_guess)] for f in d_fit ]
Pvinv = [ minv(block) for block in Pv ]
Pfreeze = [ f[1][0:len(original_guess)] for f in d_fit ]

# Response matrices used by dfit():  dpv = delA dm + delB dpfreeze
delA = [ -np.dot(inv, f[2][0:len(original_guess)]) for inv, f in zip(Pvinv, d_fit) ]
delB = [ -np.dot(inv, frozen) for inv, frozen in zip(Pvinv, Pfreeze) ]


print("Fit central values (in %g s):" % (t1-t0))
for i,r in enumerate(rrcomb):
    print("%s: %s -> %s" % (r, guess, d_fit[i][0]))

# Cross-check of the Hessian-based approximation dfit() against explicit
# refits of a block.  The check is currently DISABLED: the unconditional
# `break` leaves the loop before any statement below it is executed.
for j in range(len(jk.blocks)):
    break
    # Build real lists: under Python 3 filter() returns an iterator, which
    # supports neither len() nor a meaningful != comparison.
    cmpA=[ x for x in jk.blocks[j].tolist() if not np.isnan(x) ]
    cmpB=[ x for x in d_mean.tolist() if not np.isnan(x) ]
    if len(cmpA) != 0 and len(cmpA) == len(cmpB) and cmpA != cmpB:
        # central fit values of all ranges as one flat array
        cnt=np.ravel([ r[0] for r in d_fit ])

        eps=1
        d_fluct_P = (jk.blocks[j] - d_mean)*eps + d_mean
        rs0=jks2.cut(0)
        rs1=jks2.cut(j)

        # refit with data and keys shifted in the positive direction
        guess_pre=guess
        rsP=dict([ (k,(rs1[k]-rs0[k])*eps+rs0[k]) for k in rs0 ])
        guess=update_keys(rsP,guess)
        refP=np.ravel([ r[0] for r in fit(d_fluct_P) ]) - cnt

        # refit with data and keys shifted in the negative direction
        rsM=dict([ (k,-(rs1[k]-rs0[k])*eps+rs0[k]) for k in rs0 ])
        guess=update_keys(rsM,guess)
        d_fluct_M = -(jk.blocks[j] - d_mean)*eps + d_mean
        refM=np.ravel([ r[0] for r in fit(d_fluct_M) ]) - cnt
        ref=(refP - refM) / 2.

        # restore the guess and evaluate the approximation for comparison
        guess=guess_pre
        tst=dfit(d_fluct_P,rsP) - cnt

        # largest relative deviation between refit and approximation
        eps_array = np.abs((ref - tst) / ref)
        eps=np.max(eps_array)

        print("Jackknife fits using Hessian accuracy: %g" % (eps))
        # TODO: optionally abort when eps > eps_require_appx_precision and
        # stop after the first block that was checked.

# Run the fit on the combined input vector stored under "#.#fit#.#" via
# jks2.apply and store the result under the fit tag.
t0 = time.time()
jk = jks2.apply(lambda r: sfit(r["#.#fit#.#"]))
t1 = time.time()
print("Total fitting took %g s" % (t1-t0))

res.add(fittag, jk)

# Additionally store, per tag and fit range, the input data with every
# point outside the fit range replaced by NaN.
# TODO: create .output with the fit function values at the points
for i, t in enumerate(tag):
    for j in range(nranges):
        selected = jks2.apply(lambda r: sel(r[t], i, j))
        res.add("%s.%s.input.%d" % (fittag, t, j), selected)

# Compression of the result (res.compress) is currently disabled.

res.save(fno)
