#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import jks, sys, os, math, glob, numpy, fnmatch
# Command line: <prog> fn etag1 pat1 [etag2 pat2 ...]
# A valid call therefore has an even number of argv entries and at least one
# (etag, pat) pair; anything else prints the usage line and stops.
if not (len(sys.argv) >= 4 and len(sys.argv) % 2 == 0):
    print("%s fn etag1 pat1 [etag2 pat2 ...]" % sys.argv[0])
    sys.exit(0)

# First argument: name handed to jk.save() at the end of the script.
fn = sys.argv[1]

# Remaining arguments alternate between a tag prefix and a glob pattern.
etag = sys.argv[2::2]
pat = sys.argv[3::2]

# Number of (etag, pat) pairs.
N = len(pat)

# Container that collects one entry per tag before it is saved.
jk = jks.resamples()

def get_tag(f, idx, ll):
    """Cut the wildcard part out of file name ``f`` for pattern ``pat[ll]``.

    ``idx`` is the offset at which the slice of ``f`` starts (the caller
    passes the position of ``*`` in ``pat[ll]``).  The number of characters
    of ``pat[ll]`` that follow position ``idx`` is removed from the end of
    ``f``; the remaining middle part is returned as a string.

    NOTE(review): this presumably expects a pattern with exactly one ``*``
    and no other glob metacharacters -- confirm against the callers.
    """
    # Number of pattern characters that come after position ``idx``.
    suffix_len = len(pat[ll]) - idx - 1
    if suffix_len != 0:
        return f[idx:-suffix_len]
    # Nothing follows the wildcard, so there is nothing to strip on the right.
    return f[idx:]


def mm(ln, i):
    """Parse a line of the form ``"<index> <value>"``.

    The line is split on single spaces.  The first field must be the integer
    ``i`` (checked with ``assert``); the second field is returned as a float.
    Any further fields are ignored.
    """
    fields = ln.split(" ")
    line_index = int(fields[0])
    assert line_index == i
    return float(fields[1])

def proj_real(a):
    """Return a list with the ``.real`` attribute of each element of ``a``."""
    return [value.real for value in a]

def proj_imag(a):
    """Return a list with the ``.imag`` attribute of each element of ``a``."""
    return [value.imag for value in a]

def load(fn):
    """Read every tag of file ``fn`` and split it into real and imaginary parts.

    The file is opened through ``jks.corrIO.reader`` and its ``tags``
    attribute is read (used here via ``.keys()`` and indexing).  For each
    distinct tag name ``T`` the returned dict gets two entries:
    ``T + ".r"`` holding ``proj_real`` of the tag's data and ``T + ".i"``
    holding ``proj_imag`` of it.

    Side effects: prints two progress messages and rebinds the module-level
    names ``citag`` and ``cotag`` (the same list of tag names of this file,
    in set-iteration order) and ``N`` (the length of that list).
    """
    # NOTE(review): N is also the number of (etag, pat) pairs computed from
    # the command line at module level; every call to load() overwrites it.
    global N, citag, cotag

    print("Loading %s" % fn)
    t = jks.corrIO.reader(fn).tags
    print("Done")

    # Collect the distinct tag names found in this file.
    unique_tags = set([])
    for name in t.keys():
        unique_tags.add(name)

    citag = list(unique_tags)
    cotag = citag  # output names are identical to the input names
    N = len(citag)

    res = {}
    for src, dst in zip(citag, cotag):
        res[dst + ".r"] = proj_real(t[src])
        res[dst + ".i"] = proj_imag(t[src])
    return res

# fns:  every matched file, pattern by pattern, each group in ascending
#       config-number order.
# gtag: file name -> "<etag>-<config number zero-padded to 8 digits>".
fns = []
gtag = {}

for ll in range(N):
    matches = glob.glob(pat[ll])

    # Position of the wildcard; get_tag() starts cutting the file name here.
    idx = pat[ll].index('*')

    # check that projecting config numbers to integers gives a unique mapping
    tags = [int(get_tag(name, idx, ll)) for name in matches]
    assert len(set(tags)) == len(matches)

    # Walk the matches in ascending config-number order, recording the file
    # and its global tag.
    for k in numpy.argsort(tags):
        fns.append(matches[k])
        gtag[matches[k]] = "%s-%8.8d" % (etag[ll], tags[k])

if not fns:
    print("Attention: no file loaded for", sys.argv)
    sys.exit(1)

# Files ordered by their global tag string.
sorted_fns = sorted(fns, key=gtag.__getitem__)

# Load every file once, in the order the files were collected.
data = {name: load(name) for name in fns}

for ctag in data[fns[0]]:
    # Only tags that are present in every loaded file are processed.
    if any(ctag not in per_file for per_file in data.values()):
        continue
    samples = {gtag[name]: data[name][ctag] for name in sorted_fns}
    m = jks.measurements(samples)
    jk.add(ctag, m.jackknife().prepare(lambda mi: mi.mean()))

jk.save(fn)



