#!/usr/bin/env python3
import sys
import argparse
import random
from functools import partial
from multiprocessing import Pool, cpu_count
import numpy as np
import pandas as pd
from scipy.stats import chi2,norm
from hapflk import popgen, InputOutput, hapflkadapt



def eigen_contrib_diallelic(p,w,un,invF,D,Q):
    '''
    Split the FLK statistic of one diallelic locus into one contribution
    per eigen vector, given the spectral decomposition Q*D*Q.T supplied
    by the caller.

    Parameters:
    ---------------

    p : vector of population allele frequencies at a locus
    w : weight vector; the ancestral frequency estimate is dot(w.T, p)
    un : vector multiplied by p0hat to centre p
         (presumably a vector of ones -- confirm with the caller)
    invF : matrix entering the constant through un * invF * un.T
           (presumably the inverse of the kinship matrix -- confirm)
    D : vector of eigen values
    Q : matrix whose columns are the eigen vectors

    Return Value:
    -----------------
    a tuple of :
    - p0hat : estimated allele frequency in the ancestral population
    - Z : scaled contribution of each eigen vector; the factor
          sqrt(cste) is already applied, so the statistic is sum(Z^2)
    - cste : the multiplying constant that was folded into Z
    '''
    p0hat = np.dot(w.T, p)

    # Scalar un * invF * un.T used by the multiplying constant
    quad_form = np.dot(un, np.dot(invF, un.T))
    cste = (1 - (1/quad_form))/(p0hat*(1-p0hat))

    # Centre on the ancestral estimate, then project on the eigen vectors
    centred = p - (p0hat*un)
    projected = np.dot(centred, Q)

    # Rescale each component by 1/sqrt(eigen value), then by sqrt(cste)
    rescaling = np.diag(np.sqrt(1/D))
    Z = np.dot(projected, rescaling)
    Z *= np.sqrt(cste)
    return p0hat, Z, cste


def _flk_contributions(frq, kinship, ncpu):
    '''
    Build the FLK test object for `kinship` and evaluate
    eigen_contrib_diallelic on every column (SNP) of `frq` using a pool
    of `ncpu` worker processes.

    Return Value:
    -----------------
    a tuple of :
    - ff : the popgen.FLK_test object built from kinship
    - res : list of (p0hat, Z, cste) tuples, one per column of frq,
            in column order
    '''
    ff = popgen.FLK_test(kinship)
    flkfunc = partial(eigen_contrib_diallelic, w=ff.w, un=ff.un,
                      invF=ff.invF, D=ff.D, Q=ff.Q)
    with Pool(ncpu) as pool:
        res = pool.map(flkfunc, [frq[:, s] for s in range(frq.shape[1])])
    return ff, res


def main():
    '''
    Command line entry point.

    Reads a delimited table of allele frequencies, then:
    1. builds a first kinship matrix with popgen.popKinship_new and uses
       it to estimate the ancestral frequency (pzero) of every SNP;
    2. keeps the SNPs with 0.05 < pzero < 0.95, standardises their
       frequencies and takes the between-population covariance (plus a
       small ridge on the diagonal) as the kinship matrix, written to
       <prefix>_cov.txt;
    3. recomputes the statistic with that matrix and writes one line per
       SNP (snp, pzero, FLK, df, pval) to <prefix>.flk.

    Raises ValueError when too few SNPs pass the frequency filter or the
    resulting covariance matrix has a non-positive (or undefined) trace.
    '''
    # SNPs whose ancestral frequency estimate is outside this range are
    # not used for the covariance and get a NaN p-value.
    min_freq, max_freq = 0.05, 0.95
    # Fraction of the mean diagonal value added as a ridge
    ridge = 0.01

    parser = argparse.ArgumentParser()
    parser.add_argument('-i', dest='input_file', help='Input File Path', required=True)
    parser.add_argument('-o', dest='prefix', help='Output File prefix', default='flkfreq')
    parser.add_argument('--popcols', dest='transpose', help='Populations as columns (SNPs as rows)',
                        action='store_true', default=False)
    parser.add_argument('--delim', dest='delim', help='Delimiter', default=';')
    parser.add_argument('--ncpu', dest='ncpu', help='use multiple cpus', default=cpu_count(), type=int)

    if len(sys.argv) < 2:
        parser.print_help()
        sys.exit(1)
    opts = parser.parse_args(sys.argv[1:])

    fdat = pd.read_csv(opts.input_file, delimiter=opts.delim)
    if opts.transpose:
        popnames = fdat.columns.tolist()
        snpnames = fdat.index.tolist()
    else:
        popnames = fdat.index.tolist()
        snpnames = fdat.columns.tolist()
    frq = np.array(fdat, dtype=float)
    subset = np.isnan(frq)
    # NOTE(review): frq is a plain ndarray, not a numpy.ma masked array,
    # so this assignment cannot record a mask: missing entries are
    # overwritten with whatever value np.ma.masked converts to (believed
    # to be 0.0 -- TODO confirm). Decide how missing frequencies should be
    # handled by the hapflk routines before changing this line.
    frq[subset] = np.ma.masked

    # From here on rows are populations and columns are SNPs
    if opts.transpose:
        frq = frq.T
    npops = frq.shape[0]

    pfx = opts.prefix
    hzy = popgen.heterozygosity(frq)
    rey = popgen.reynolds(frq)
    kinship = popgen.popKinship_new(rey, popnames, None, fprefix=pfx, hzy=hzy, dump_tree=False)

    ## First pass: only the ancestral frequency estimates are needed
    _, res = _flk_contributions(frq, kinship, opts.ncpu)
    pzero = np.array([r[0] for r in res])
    sub = (pzero > min_freq) & (pzero < max_freq)
    if np.count_nonzero(sub) < 2:
        raise ValueError('fewer than two SNPs have an ancestral frequency estimate '
                         'strictly between %g and %g; cannot estimate the covariance matrix'
                         % (min_freq, max_freq))
    myf = frq[:, sub]
    m = pzero[sub]
    s = np.sqrt(m*(1 - m))
    myfstd = (myf - m)/s
    covkin = np.cov(myfstd, rowvar=True)
    ## Add a small term on the diagonal to ensure stable inverse
    val = np.trace(covkin)/npops
    # Explicit check rather than assert, which is skipped under python -O.
    # Written as "not val > 0" so that a NaN trace is rejected as well.
    if not val > 0:
        raise ValueError('covariance matrix of standardised frequencies has a '
                         'non-positive or undefined trace (%s)' % val)
    covkin = covkin + np.diag(np.ones(npops)*ridge*val)
    with open(pfx + '_cov.txt', 'w') as fcov:
        for name, row in zip(popnames, covkin):
            # names may be integers (default index), hence str()
            print(' '.join([str(name)] + [str(v) for v in row]), file=fcov)
    kinship = covkin

    ## Compute FLK
    ff, res = _flk_contributions(frq, kinship, opts.ncpu)
    df = ff.dimension - 1
    pzero = [r[0] for r in res]
    flk = [sum(v**2 for v in r[1]) for r in res]
    pval = [np.nan if (p0 < min_freq or p0 > max_freq) else chi2.sf(stat, df)
            for p0, stat in zip(pzero, flk)]
    with open(opts.prefix + '.flk', 'w') as fout:
        print('snp', 'pzero', 'FLK', 'df', 'pval', file=fout)
        for name, p0, stat, pv in zip(snpnames, pzero, flk, pval):
            print(name, p0, stat, df, pv, file=fout)

# Run the analysis only when executed as a script, not when imported
if __name__ == "__main__":
    main()
