#!python
#
#    JKS - Measurement database system
#    Copyright (C) 2013-2024  Christoph Lehner (christoph.lehner@ur.de, https://github.com/lehner/jks)
#
#    This program is free software; you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation; either version 2 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License along
#    with this program; if not, write to the Free Software Foundation, Inc.,
#    51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#
import jks, sys, os, math, glob, numpy
# Command line: fn ctag followed by one or more (etag, pattern) pairs, so the
# total argument count (including the script name) must be odd and >= 5.
argc = len(sys.argv)
if argc % 2 == 0 or argc < 5:
    print("%s fn ctag etag1 pat1 [etag2 pat2 ...]" % sys.argv[0])
    sys.exit(0)

# fn is later handed to jks2.save(), ctag to jks2.add().
fn, ctag = sys.argv[1:3]

# Number of (etag, pattern) pairs; the tags sit at positions 3, 5, ... and the
# matching glob patterns at positions 4, 6, ...
n = (argc - 3) // 2
etag = sys.argv[3::2]
pat = sys.argv[4::2]

# Container that receives the result under ctag at the end of the script.
jks2 = jks.resamples()

def get_tag(f,idx,ii):
    """Cut the wildcard part out of file name ``f``.

    ``idx`` is taken to be the position of the wildcard character inside the
    module-level pattern ``pat[ii]``.  Everything in ``f`` from ``idx`` onwards
    is returned, minus as many trailing characters as follow position ``idx``
    in the pattern.
    """
    # Number of pattern characters behind position idx.
    suffix_len = len(pat[ii]) - idx - 1
    # A zero-length suffix must slice to the end of the string: f[idx:-0]
    # would be empty, so fall back to an open-ended slice in that case.
    return f[idx:(-suffix_len or None)]


def mm(ln,i):
    """Parse one data line and return its second column as a float.

    The line is split on single spaces and empty fields (produced by runs of
    spaces) are dropped.  The first remaining field must be the integer ``i``;
    this is checked with an assert.
    """
    fields = [tok for tok in ln.split(" ") if tok]
    row = int(fields[0])
    assert row == i
    return float(fields[1])

def load(fn):
    """Read a data file and return its second column as a list of floats.

    Empty lines and lines starting with "T" are skipped.  Each remaining line
    is parsed by ``mm``, which also asserts that the first column equals the
    zero-based position of the line among the kept lines.

    Raises IndexError if the file contains no data lines, and AssertionError
    if the first data line has fewer than two space-separated fields.
    """
    # Use a context manager so the file handle is closed right away instead of
    # being left to the garbage collector.
    with open(fn) as fh:
        lns = [ln for ln in fh.read().split("\n") if ln != "" and ln[0] != "T"]

    # Sanity check on the first data line: need at least index and value.
    ncols = len(lns[0].split(" "))
    assert(ncols >= 2)

    return [mm(ln, i) for i, ln in enumerate(lns)]


# All matched file names (ensemble by ensemble) and their global tags.
fns = []
gtag = {}

for ll in range(n):

    pattern = pat[ll]
    fns_ll = glob.glob(pattern)

    # position of the wildcard inside the pattern
    idx = pattern.index('*')
    # kept from the original script; not read anywhere in the visible code
    fmt = pattern.replace("*", "%d")

    # check that projecting config numbers to integers gives a unique mapping
    tags = [int(get_tag(f, idx, ll)) for f in fns_ll]
    assert(len(set(tags)) == len(fns_ll))

    # order this ensemble's files by increasing config number
    ida = numpy.argsort(tags)
    ordered = [(fns_ll[j], tags[j]) for j in ida]
    fns_ll = [f for f, _ in ordered]

    # append file names
    fns.extend(fns_ll)

    # global tag: ensemble tag plus config number zero-padded to eight digits
    gtag.update({f: "%s-%8.8d" % (etag[ll], t) for f, t in ordered})

sorted_fns = sorted(fns, key=gtag.__getitem__)

# Load every file, then cut all series down to the shortest common length.
data = {f: load(f) for f in fns}

min_len = min(len(series) for series in data.values())
data = {f: series[:min_len] for f, series in data.items()}

# Hand the tagged series to jks, build the jackknife of the mean, store it
# under ctag and save to fn.
m = jks.measurements({gtag[f]: data[f] for f in sorted_fns})
jk = m.jackknife().prepare(lambda mi: mi.mean())
jks2.add(ctag, jk)
jks2.save(fn)



