#!/usr/bin/env python3
import sys
import datetime as dt
import numpy as np
from hapflk import hapflk, utils, popgen

 
def main():
    """Command-line entry point: estimate population kinship from input data.

    Steps, in order:

    1. Parse options from ``sys.argv`` and read the input data.
    2. Compute per-population allele frequencies.
    3. Select a subset of SNPs, draw a random sample of it and compute
       Reynolds distances and heterozygosity on that sample.
    4. Build the kinship matrix with ``popgen.popKinship_new``.
    5. If the ``covkin`` option is set, compute an empirical covariance
       matrix of standardised frequencies and write it to
       ``<prefix>_cov.txt`` (one line per population: name followed by
       the matrix row, space separated).

    Raises:
        ValueError: if no SNP passes the frequency filter, or if the mean
            diagonal of the empirical covariance matrix is not strictly
            positive (this includes NaN).
    """
    counter = utils.Stepper()
    myopts = hapflk.HapFLK.get_opts(sys.argv)
    mymod = hapflk.HapFLK()

    counter.new("Reading In Data")
    data = mymod.read_input(myopts)

    counter.new("Computing Allele Frequencies")
    res = data.compute_pop_frq()
    # frq is indexed [row, SNP]; rows are presumably the populations listed
    # in `pops`, in the same order -- TODO confirm against compute_pop_frq.
    frq = res['freqs']
    pops = res['pops']

    counter.new("Setting up SNP subset")
    minfrq = np.min(frq, axis=0)
    maxfrq = np.max(frq, axis=0)
    ## Keep SNPs whose largest frequency across rows exceeds a and whose
    ## smallest frequency is below 1-a, with a small (0.1). This drops SNPs
    ## that are near 0 in every row or near 1 in every row.
    a = 0.1
    subset = (maxfrq > a) & (minfrq < (1 - a))
    nsnp = np.sum(subset)
    print("SNP subset : %d/%d"%(nsnp,len(maxfrq)))
    if nsnp == 0:
        # Without this guard the sampling probability below divides by zero
        # and the downstream statistics are computed on empty arrays.
        raise ValueError(
            "no SNP passes the allele frequency filter; "
            "cannot estimate kinship")
    subfreqs = frq[:, subset]
    # Each retained SNP is sampled independently with this probability, so
    # about `reysnps` SNPs are used on average (all of them if nsnp is
    # smaller than reysnps).
    pbinom = min(1.0, float(myopts.reysnps) / nsnp)
    rand_subset = np.array(np.random.binomial(1, pbinom, nsnp), dtype=bool)
    reynolds = popgen.reynolds(subfreqs[:, rand_subset])
    hzy = popgen.heterozygosity(subfreqs[:, rand_subset])

    kinship = popgen.popKinship_new(reynolds, pops, myopts.outgroup,
                                    fprefix=myopts.prefix,
                                    keep_outgroup=myopts.keepOG,
                                    hzy=hzy,
                                    dump_tree=myopts.tree)
    if myopts.covkin:
        counter.new('Computing Covariance Matrix')
        ff = popgen.FLK_test(kinship)
        ## calculate hat(p0): first value returned by eigen_contrib for
        ## each SNP of the filtered subset
        pzero = np.array([ff.eigen_contrib(subfreqs[:, snp])[0]
                          for snp in range(nsnp)])
        # Restrict to SNPs with hat(p0) away from 0 and 1 so the
        # standard deviation used for scaling below is not close to zero.
        sub = (pzero > 0.05) & (pzero < 0.95)
        ## calculate empirical covariance matrix
        myf = subfreqs[:, sub]
        m = pzero[sub]
        sd = np.sqrt(m * (1 - m))
        # Centre and scale every SNP column by its hat(p0) and sd.
        myfstd = (myf - m) / sd
        # rowvar=True: each row of myfstd is one variable, each SNP one
        # observation.
        covkin = np.cov(myfstd, rowvar=True)
        ## Add a small term on the diagonal to ensure stable inverse
        # NOTE(review): the ridge is sized from kinship.shape[0]; this
        # assumes kinship and covkin have the same dimension -- confirm
        # for runs where the outgroup is not kept.
        val = np.trace(covkin) / kinship.shape[0]
        if not val > 0:
            # Explicit check instead of `assert`, which is removed when
            # Python runs with -O. `not val > 0` is also true for NaN.
            raise ValueError(
                "mean diagonal of the empirical covariance matrix must be "
                "positive, got %s" % val)
        covkin = covkin + np.diag(np.ones(kinship.shape[0]) * 0.01 * val)
        with open(myopts.prefix + '_cov.txt', 'w') as fcov:
            for i, pi in enumerate(pops):
                tw = [pi]
                tw.extend(str(covkin[i, j]) for j in range(len(pops)))
                print(' '.join(tw), file=fcov)
    counter.end()
    
# Run the pipeline only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
