#!/usr/bin/env python3
import sys
import argparse
import random
from functools import partial
from multiprocessing import Pool, cpu_count
import numpy as np
from scipy.stats import chi2,norm
from hapflk import popgen, InputOutput, hapflkadapt


# Order of the colon-separated count fields in each pool column of a sync
# line; parse_sync_line looks up the reference allele in this list to find
# which field holds its count.
sync_codes=['A','T','C','G','N','DEL']

def get_line(fic, position):
    '''
    Return (right-stripped) a line of the open file `fic` located from
    the offset `position`.

    The file is scanned backwards one character at a time, starting at
    `position`, until a newline is met; the line that follows this
    newline is returned. A newline sitting exactly at the starting
    offset is not taken as the boundary, so an offset pointing at the
    end of a line yields that same line. When offset 0 is reached
    without finding a boundary, the first line of the file is returned.
    '''
    origin = position
    cursor = position
    while True:
        fic.seek(cursor)
        char = fic.read(1)
        # a newline only counts when it is not the character we started on
        found_boundary = char == '\n' and fic.tell() != origin + 1
        if found_boundary:
            return fic.readline().rstrip()
        if cursor == 0:
            # start of file: the character just read belongs to the line
            return char + fic.readline().rstrip()
        cursor -= 1

def get_random_lines(fic, num_of_lines):
    '''
    Draw `num_of_lines` lines from the open file `fic`.

    Offsets are drawn uniformly (with replacement) between 0 and the
    last character of the file, sorted, and each one is resolved to a
    line with get_line. The returned list therefore follows file order
    and may contain the same line more than once.
    '''
    fic.seek(0, 2)
    max_offset = fic.tell() - 1  # last character, i.e. one before EOF
    offsets = [random.randint(0, max_offset) for _ in range(num_of_lines)]
    return [get_line(fic, offset) for offset in np.sort(offsets)]



def parse_sync_line(ligne):
    '''
    Read one line of a sync file and return the reference allele
    frequency of each pool.

    Input example:
    NC_010443.5 112 C  0:0:23:0:0:0    0:4:18:0:0:0    0:8:13:0:0:0  ...

    The first three fields are chromosome, position and reference
    allele; every following field holds the colon-separated counts of
    one pool, in the order given by sync_codes.

    Output:
    a tuple (nobs, fpool_pref, fpool_raw) where
    - nobs : total count summed over all fields of all pools
    - fpool_pref : numpy.array of floats, shrunk frequency of each pool,
      (a + ref) / (a + b + depth), with pseudo-counts a and b derived
      from the frequency pbar over all pools (the smaller of the two is
      set to 1 and a / (a + b) = pbar)
    - fpool_raw : numpy.array of floats, ref / depth of each pool, nan
      for a pool with zero depth

    When the reference allele is absent from (resp. is the only allele
    in) all pools, both arrays are filled with 0 (resp. 1), including
    for pools with zero depth.

    Raises ValueError if the reference allele is not in sync_codes.
    '''
    buf = ligne.split()
    ref_allele = buf[2].upper()
    try:
        ref_position = sync_codes.index(ref_allele)
    except ValueError:
        raise ValueError('Unknown reference allele %r at %s %s'
                         % (buf[2], buf[0], buf[1])) from None

    ## per pool depth and reference allele count
    nrefs_pool = []
    nobs_pool = []
    for poolobs in buf[3:]:
        counts = [float(x) for x in poolobs.split(':')]
        nobs_pool.append(sum(counts))
        nrefs_pool.append(counts[ref_position])
    nobs = sum(nobs_pool)
    nrefs = sum(nrefs_pool)
    npools = len(nrefs_pool)

    ## monomorphic sites: no shrinkage possible (pbar is 0 or 1)
    if nrefs == 0:
        return (nobs, np.zeros(npools, dtype=float), np.zeros(npools, dtype=float))
    if nrefs == nobs:
        return (nobs, np.ones(npools, dtype=float), np.ones(npools, dtype=float))

    pbar = nrefs / nobs

    ## pseudo-counts such that a / (a + b) = pbar and min(a, b) = 1
    if pbar < 0.5:
        a_prior = 1
        b_prior = (1 - pbar) / pbar
    else:
        a_prior = pbar / (1 - pbar)
        b_prior = 1

    fpool_pref = []
    fpool_raw = []
    for nref, depth in zip(nrefs_pool, nobs_pool):
        num = a_prior + nref
        den = num + b_prior + (depth - nref)
        fpool_pref.append(num / den)
        # a pool without any read has no defined raw frequency
        fpool_raw.append(nref / depth if depth != 0 else np.nan)
    return (nobs, np.array(fpool_pref), np.array(fpool_raw))


def eigen_contrib_diallelic(p,w,un,invF,D,Q):
    '''
    Computes orthogonalized contributions of each component to the FLK
    statistic, using the spectral decomposition of the variance
    covariance matrix V(p) V=cste*F=cste*Q*D*Q.T where Q is the matrix
    of eigen vectors and D the corresponding diagonal matrix with
    eigen values on the diagonal.

    Parameters:
    ---------------

    p : vector of population allele frequencies at a locus
    w : vector of weights; the ancestral frequency estimate is w.T . p
    un : vector multiplied by p0hat to center p and used in the
         quadratic form un . invF . un.T
    invF : matrix used in the quadratic form un . invF . un.T
    D : 1-d vector of eigen values
    Q : matrix of eigen vectors (one per column)

    Return Value:
    -----------------
    a tuple of :
    - p0hat : estimated allele frequency in the ancestral population
    - Z : contributions of each PC to the FLK statistic, already
          multiplied by sqrt(cste)
    - cste : multiplying constant to the test
    Because Z includes sqrt(cste), sum(Z^2) = cste*sum(L^2) where
    L = (p - p0hat*un) . Q / sqrt(D) are the unscaled contributions.
    '''
    p0hat = np.dot(w.T, p)
    quad_form = np.dot(un, np.dot(invF, un.T))
    cste = (1 - 1/quad_form) / (p0hat*(1 - p0hat))
    Z = np.dot(p - p0hat*un, Q)
    # Scale each component by 1/sqrt(eigen value). Element-wise product is
    # equivalent to multiplying by diag(sqrt(1/D)) without building the
    # dense matrix on every call.
    Z = np.multiply(Z, np.sqrt(1/D))
    Z *= np.sqrt(cste)
    return p0hat,Z,cste

def _bind_flk_contrib(kinship):
    '''
    Build the popgen.FLK_test object of a kinship matrix and bind its
    w, un, invF, D and Q attributes to eigen_contrib_diallelic.

    Returns the tuple (ff, flkfunc); flkfunc takes a vector of
    population frequencies as only argument.
    '''
    ff = popgen.FLK_test(kinship)
    flkfunc = partial(eigen_contrib_diallelic, w=ff.w, un=ff.un, invF=ff.invF, D=ff.D, Q=ff.Q)
    return ff, flkfunc


def _covariance_kinship(freqs, popnames, pfx, ncpu, pzero_kinship=None):
    '''
    Estimate a population kinship matrix as the covariance of
    standardized allele frequencies.

    freqs : array of frequencies, one row per population, one column per SNP
    popnames : population names, in the order of the rows of freqs
    pfx : prefix of the output files (pfx + '_cov.txt' is written here)
    ncpu : number of worker processes
    pzero_kinship : kinship used to estimate the ancestral frequency of
        each SNP; if None, the kinship returned by popgen.popKinship_new
        for these frequencies is used.

    Returns the covariance matrix with 1% of its mean diagonal value
    added on the diagonal. Raises ValueError if that mean is not > 0.
    '''
    npops = len(popnames)
    hzy = popgen.heterozygosity(freqs)
    rey = popgen.reynolds(freqs)
    # always called, as in the original code, even when pzero_kinship is given
    tree_kinship = popgen.popKinship_new(rey, popnames, None, fprefix=pfx, hzy=hzy, dump_tree=False)
    if pzero_kinship is None:
        pzero_kinship = tree_kinship

    ## ancestral allele frequency of each SNP
    _, flkfunc = _bind_flk_contrib(pzero_kinship)
    with Pool(ncpu) as pool:
        res = pool.map(flkfunc, [freqs[:, s] for s in range(freqs.shape[1])])
    pzero = np.array([x[0] for x in res])

    ## covariance of frequencies standardized by the ancestral frequency
    keep = (pzero > 0.05) & (pzero < 0.95)
    m = pzero[keep]
    std_freqs = (freqs[:, keep] - m) / np.sqrt(m * (1 - m))
    covkin = np.cov(std_freqs, rowvar=True)

    ## Add a small term on the diagonal to ensure stable inverse
    val = np.trace(covkin) / npops
    if not val > 0:
        raise ValueError('Mean diagonal of the covariance kinship is not positive (%s) for %s'
                         % (val, pfx))
    covkin = covkin + np.diag(np.ones(npops) * 0.01 * val)

    with open(pfx + '_cov.txt', 'w') as fcov:
        for name, row in zip(popnames, covkin):
            print(' '.join([name] + [str(v) for v in row]), file=fcov)
    return covkin


def _run_batch(nsnp, snp_info, raw_freqs, hrc_freqs, pool, ff, flkfunc,
               fout, frq_file, covdata, kinship, opts):
    '''
    Compute and write the results of one batch of SNPs.

    nsnp : number of SNPs read so far (progress messages only)
    snp_info : list of [chr, pos, refa, DP], one per SNP of the batch
    raw_freqs, hrc_freqs : raw and shrunk population frequencies of the
        SNPs of the batch, in the order of snp_info
    pool : multiprocessing pool used to run flkfunc
    ff, flkfunc : as returned by _bind_flk_contrib
    fout, frq_file : open .flk and .frq output files
    covdata : covariates data or None; when set, FLKadapt results are
        appended to <prefix>_<covariable>.flkadapt
    kinship : kinship matrix given to FLKadapt
    opts : parsed command line options (prefix and ncpu are used)
    '''
    df = ff.dimension - 1
    print('Computing FLK : %10d'%nsnp)
    sys.stdout.flush()
    res = pool.map(flkfunc, raw_freqs)
    ## write output
    for info, raw, hrc, contrib in zip(snp_info, raw_freqs, hrc_freqs, res):
        pzero = contrib[0]
        flk = sum([v**2 for v in contrib[1]])
        # no p-value for SNPs with an extreme ancestral frequency
        pval = np.nan if (pzero < 0.05 or pzero > 0.95) else chi2.sf(flk, df)
        print(*info, pzero, flk, df, pval, file=fout)
        print(*info, 'raw', *raw, file=frq_file)
        print(*info, 'hrc', *hrc, file=frq_file)

    if not covdata:
        return

    print('Computing FLKadapt : %10d'%nsnp)
    sys.stdout.flush()
    freqs = np.array(hrc_freqs).T
    for f in covdata['covariables']:
        myprefix = opts.prefix+'_'+f
        print('Covariate ', f)
        sys.stdout.flush()
        with hapflkadapt.FLKadapt(kinship, covdata['DesignMatrices'][f], prefix = myprefix, ncpu=opts.ncpu, stdize=True) as flkadapt:
            flkadapt.correct_lrt = lambda x : flkadapt.df
            bflka = flkadapt.bayesFLKadapt( freqs)
            flka = flkadapt.FLKadapt( freqs )
        nskipped = 0
        with open( myprefix + '.flkadapt','a') as fout_adapt:
            for i_s, d in enumerate(flka):
                info = snp_info[i_s]
                try:
                    print( *info, d['H0']['pzero'],d['H0']['pzsucc'], bflka[i_s],
                           d['H0']['llik'], d['H1']['llik'], d['lrt'],  d['pvalue'],
                           d['H1']['pzero'], *d['H1']['beta'], d['H1']['pzsucc'], file=fout_adapt)
                except Exception:
                    # best effort: a SNP with an incomplete result is left
                    # out of the file (KeyboardInterrupt is not swallowed)
                    nskipped += 1
        if nskipped:
            print('Warning: %d SNPs without FLKadapt result for covariate %s'
                  % (nskipped, f), file=sys.stderr)


def main():
    '''
    Command line entry point.

    Reads a sync file, obtains a population kinship matrix (read from a
    file, or estimated from a random sample of SNPs), then goes through
    the sync file in batches to write:
    - <prefix>.frq : raw ('raw') and shrunk ('hrc') frequencies per population
    - <prefix>.flk : FLK statistic and p-value of each SNP
    - <prefix>_<covariable>.flkadapt : FLKadapt results, when covariates
      are provided
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--sync', dest="filename",help='Genotype input file name (Sync format)')
    parser.add_argument('-p','--prefix',dest='prefix',help='prefix for output files',default='poolflk')
    parser.add_argument('--ncpu', dest='ncpu', help='use mutliple cpus', default = cpu_count(), type=int)
    parser.add_argument('--kinship',help='Read population kinship from file (if None, kinship is estimated)',metavar='FILE',default=None)
    # integer default: argparse does not apply `type` to non-string defaults
    parser.add_argument('-L',dest='kinsnps', help='Number of SNPs to estimate kinship',default=100000,type=int)
    parser.add_argument('--popnames',help='Read population names from file (one per line)', metavar='FILE',default=None)
    parser.add_argument('--remove', help='Remove populations (list of indices separated by commas)',default='')
    parser.add_argument('--minDP', help='Consider SNPs with dp > ( MINDP * npools ) for kinship', default=30, type=int)
    # NOTE(review): opts.flk is parsed but never consulted below, FLK is
    # always computed.
    parser.add_argument('--no-FLK',help='Do not compute FLK, only FLKadapt', dest='flk', default=True, action='store_false')

    InputOutput.populate_covariate_args(parser)

    if len(sys.argv)<2:
        parser.print_help()
        sys.exit(1)
    opts = parser.parse_args(sys.argv[1:])
    if opts.filename is None:
        parser.error('--sync is required')

    ## number of pools = number of columns after chr, pos, ref allele
    with open(opts.filename) as sync_file:
        npools = len( sync_file.readline().split()) - 3

    if opts.popnames is None:
        popnames = ['pop'+str(i) for i in range(npools)]
    else:
        with open(opts.popnames) as f:
            popnames = [ligne.strip() for ligne in f]

    ## populations to drop are given as 1-based indices
    if len( opts.remove) > 0:
        pop2rm = [int(v)-1 for v in opts.remove.split(',')]
        print( 'Dropping populations:\n','\n'.join([popnames[i] for i in pop2rm]))
    else:
        pop2rm = []
    popidx = [ i for i in range( len( popnames)) if i not in pop2rm]
    mypopnames =  [popnames[i] for i in popidx]

    if opts.covar_file != '':
        covdata = InputOutput.get_covariates_matrix( opts.covar_file, opts.cov, opts.qcov, mypopnames)
        ## (re)create one output file per covariable, with its header
        for f in  covdata['covariables']:
            myprefix = opts.prefix+'_'+f
            with open( myprefix + '.flkadapt','w') as fout_adapt:
                print('chr','pos','refa', 'DP','pzero.null', 'converge.null', 'bf',  'L0', 'L1', 'LRT', 'pvalue', 'pzero','beta', 'converge',  file=fout_adapt)
    else:
        covdata = None

    ## Kinship calculations
    print(npools,'pools in file', opts.filename)
    if opts.kinship:
        kinship = popgen.popKinship_fromFile(opts.kinship,mypopnames)
        # the kinship read from file is used for FLK as well as for FLKadapt
        # (kinship_raw used to be left undefined in this case)
        kinship_raw = kinship
    else:
        print('Will sample',opts.kinsnps,'SNPs for kinship')
        with open(opts.filename) as sync_file:
            rndsnp = get_random_lines(sync_file, int(opts.kinsnps))
        min_total_dp = opts.minDP * npools
        snpfreqs=[]
        rawsnpfreqs=[]
        for ligne in rndsnp:
            dp, freq, rawfreq = parse_sync_line( ligne)
            ## keep SNPs deep enough and observed in every pool
            if dp > min_total_dp and not np.isnan( rawfreq).any():
                snpfreqs.append(freq[popidx])
                rawsnpfreqs.append(rawfreq[popidx])
        if not snpfreqs:
            sys.exit('No sampled SNP with DP > %d and data in all pools: cannot estimate kinship'
                     % min_total_dp)
        snpfreqs = np.array(snpfreqs).T
        rawsnpfreqs = np.array(rawsnpfreqs).T
        print("Final SNP sample for kinship with DP > ", opts.minDP*npools,":",snpfreqs.shape[1])

        print('Computing Kinship from shrinked frequencies')
        kinship = _covariance_kinship(snpfreqs, mypopnames, opts.prefix+'_hrc', opts.ncpu)

        print('Computing Kinship from raw frequencies')
        # NOTE(review): the ancestral frequencies of the raw data are
        # estimated with the kinship of the shrunk frequencies, as in the
        # original code. Confirm this is intended and not a copy-paste
        # slip; drop pzero_kinship to use the raw tree kinship instead.
        kinship_raw = _covariance_kinship(rawsnpfreqs, mypopnames, opts.prefix+'_raw', opts.ncpu,
                                          pzero_kinship=kinship)

    ### FLK calculations
    ff, flkfunc = _bind_flk_contrib(kinship_raw)
    batch_size = 100000
    with open(opts.filename) as sync_file, \
         Pool( opts.ncpu) as p, \
         open(opts.prefix+'.flk','w') as fout, \
         open(opts.prefix+'.frq','w') as frq_file:
        print( 'chr', 'pos','refa','DP', 'est', *mypopnames, file=frq_file)
        print( 'chr', 'pos', 'refa', 'DP','pzero','FLK', 'df','pval', file=fout)
        flk_tasks = []
        flkadapt_tasks = []
        snp_info = []
        nsnp = 0
        for nsnp, ligne in enumerate(sync_file, start=1):
            dp, freq, rawfreq = parse_sync_line( ligne)
            flkadapt_tasks.append(freq[popidx])
            flk_tasks.append(rawfreq[popidx])
            snp_info.append( ligne.split()[:3]+[dp] )
            if nsnp % batch_size == 0:
                _run_batch(nsnp, snp_info, flk_tasks, flkadapt_tasks, p, ff, flkfunc,
                           fout, frq_file, covdata, kinship, opts)
                flk_tasks = []
                flkadapt_tasks = []
                snp_info = []
        ## final SNPs
        if len(flkadapt_tasks) > 0:
            _run_batch(nsnp, snp_info, flk_tasks, flkadapt_tasks, p, ff, flkfunc,
                       fout, frq_file, covdata, kinship, opts)

        
# Run the command line program only when the file is executed as a script,
# not when it is imported as a module.
if __name__=='__main__':
    main()
