#!/usr/bin/env python3
import sys
import argparse
import random
from functools import partial
from multiprocessing import Pool, cpu_count
import numpy as np
from scipy.stats import chi2,norm
from hapflk import popgen, InputOutput, hapflkadapt


# Order of the colon-separated count fields inside one pool entry of a sync
# line; parse_sync_line looks up the reference allele in this list to find
# which count belongs to it.
sync_codes = ['A', 'T', 'C', 'G', 'N', 'DEL']

def get_line(fic, position):
    '''
    Return the line of an open, seekable text file that covers a given
    offset, stripped of trailing whitespace.

    The file is scanned backwards from `position`, one character at a
    time, until the newline that ends the previous line (or the start
    of the file) is found; the line that follows is then read.

    Parameters:
    ---------------

    fic : open file object supporting seek / tell / read / readline
    position : offset at which the backward scan starts

    Return Value:
    -----------------
    the line as a string, without trailing whitespace
    '''
    cursor = position
    while True:
        fic.seek(cursor)
        char = fic.read(1)
        # A newline sitting exactly at the starting offset terminates the
        # line we want, so it must not stop the backward scan.
        if char == '\n' and fic.tell() != position + 1:
            break
        if cursor == 0:
            # Reached the start of the file: the character just read is
            # the first character of the line.
            return char + fic.readline().rstrip()
        cursor -= 1
    # The file is now positioned right after the previous line's newline.
    return fic.readline().rstrip()

def get_random_lines(fic, num_of_lines):
    '''
    Draw lines from an open file by picking random offsets in it.

    `num_of_lines` offsets are drawn uniformly between 0 and the last
    character of the file, sorted, and the line covering each offset is
    fetched with get_line. Offsets are drawn independently, so the same
    line can be returned more than once.

    Return Value:
    -----------------
    list of lines (strings), in increasing order of offset
    '''
    fic.seek(0, 2)
    # tell() at the end of the file is one past the last character
    last_char = fic.tell() - 1
    offsets = [random.randint(0, last_char) for _ in range(num_of_lines)]
    return [get_line(fic, offset) for offset in np.sort(offsets)]



def parse_sync_line(ligne):
    '''
    Reads in a line of a sync file and returns reference allele
    frequencies for each pool.

    Input example:
    NC_010443.5 112 C  0:0:23:0:0:0    0:4:18:0:0:0    0:8:13:0:0:0  ...

    Output: a tuple of
    - nobs : total count summed over all fields of all pools
    - numpy.array of floats : per pool reference allele frequency,
      smoothed with pseudo-counts (see below)
    - numpy.array of floats : per pool raw reference allele frequency
      (nan for a pool that has no count at all)

    When the reference allele is never observed (resp. is the only one
    observed) both arrays are filled with zeros (resp. ones).
    '''
    fields = ligne.split()
    ref_position = sync_codes.index(fields[2].upper())

    ## per pool total counts and reference allele counts
    nobs_pool = []
    nrefs_pool = []
    for pool_field in fields[3:]:
        counts = [float(x) for x in pool_field.split(':')]
        nobs_pool.append(sum(counts))
        nrefs_pool.append(counts[ref_position])
    nobs = sum(nobs_pool)
    nrefs = sum(nrefs_pool)

    ## reference allele absent or fixed over all pools
    npools = len(nrefs_pool)
    if nrefs == 0:
        return (nobs, np.zeros(npools, dtype=float), np.zeros(npools, dtype=float))
    if nrefs == nobs:
        return (nobs, np.ones(npools, dtype=float), np.ones(npools, dtype=float))

    ## pseudo-counts: the smaller one is 1 and a_prior / (a_prior + b_prior)
    ## equals the overall reference allele frequency
    pbar = nrefs / nobs
    if pbar < 0.5:
        a_prior, b_prior = 1, (1 - pbar) / pbar
    else:
        a_prior, b_prior = pbar / (1 - pbar), 1

    fpool_pref = []
    fpool_raw = []
    for nref, ntot in zip(nrefs_pool, nobs_pool):
        num = a_prior + nref
        den = num + b_prior + (ntot - nref)
        fpool_pref.append(num / den)
        ## a pool without any count has no defined raw frequency
        fpool_raw.append(nref / ntot if ntot != 0 else np.nan)
    return (nobs, np.array(fpool_pref), np.array(fpool_raw))


def eigen_contrib_diallelic(p,w,un,invF,D,Q):
    '''
    Computes orthogonalized contributions of each component to the FLK
    statistic, using the spectral decomposition of the variance
    covariance matrix V(p) V=cste*F=cste*Q*D*Q.T where Q is the matrix
    of eigen vectors and D the corresponding eigen values.

    Parameters:
    ---------------

    p : vector of population allele frequencies at a locus
    w : weights giving the ancestral frequency estimate as w.T . p
    un : vector used to centre p (p - p0hat*un) and in un . invF . un.T
    invF : matrix used in the scaling constant
    D : eigen values (1-D, turned into a diagonal matrix here)
    Q : matrix of eigen vectors

    Return Value:
    -----------------
    a tuple of :
    - p0hat : estimated allele frequency in the ancestral population
    - Z : contributions of each PC to the FLK statistic, already
      multiplied by sqrt(cste)
    - cste : multiplying constant to the test,
      (1 - 1/(un . invF . un.T)) / (p0hat*(1-p0hat))
    '''
    p0hat = np.dot(w.T, p)
    inv_total = 1 / np.dot(un, np.dot(invF, un.T))
    cste = (1 - inv_total) / (p0hat * (1 - p0hat))
    # project the centred frequencies on the eigen vectors ...
    projected = np.dot(p - (p0hat * un), Q)
    # ... and scale each component by 1/sqrt(eigen value)
    Z = np.dot(projected, np.diag(np.sqrt(1 / D)))
    Z *= np.sqrt(cste)
    return p0hat, Z, cste

def pop_contrib_diallelic(p, w, un, invF, D, Q):
    '''
    Computes the contribution of each population to the FLK statistic.

    Parameters are those of eigen_contrib_diallelic, which is called
    first to get p0hat, Z and cste.

    Return Value:
    -----------------
    a tuple of :
    - p0hat : as returned by eigen_contrib_diallelic
    - Z : as returned by eigen_contrib_diallelic
    - popcontrib : list, for each population the product of its
      centred frequency (p - p0hat*un) with the matching entry of
      cste * (p - p0hat*un) . invF
    - signedpopcontrib : list, square root of the absolute value of
      each popcontrib entry, carrying the sign of the matching entry
      of cste * (p - p0hat*un) . invF
    '''
    p0hat, Z, cste = eigen_contrib_diallelic(p, w, un, invF, D, Q)
    deviation = p - (p0hat * un)
    weighted = np.dot(deviation, invF) * cste
    popcontrib = []
    signedpopcontrib = []
    for weighted_k, deviation_k in zip(weighted, deviation):
        product = weighted_k * deviation_k
        popcontrib.append(product)
        signedpopcontrib.append(np.sign(weighted_k) * np.sqrt(np.abs(product)))
    return p0hat, Z, popcontrib, signedpopcontrib
    
def main():
    '''
    Command line entry point: writes the signed contribution of each
    population to the FLK statistic, for every SNP of a sync file.

    Reads the population names (--popnames) and the population kinship
    (--kinship), builds the FLK test from the kinship, then for each
    line of the sync file (--sync) computes the raw pool allele
    frequencies and their signed per-population contributions.

    Output: <prefix>.popcontrib, with a header line
    'chr pos refa <population names>' followed by one line per SNP.

    Exits with a usage error if --sync or --popnames is missing, and
    with an error message if the number of pools on the first sync
    line differs from the number of population names.
    '''
    parser = argparse.ArgumentParser()
    # --sync and --popnames are opened unconditionally below, so they are
    # declared required instead of failing later on open(None).
    parser.add_argument('--sync', dest="filename", required=True,
                        help='Genotype input file name (Sync format)')
    parser.add_argument('-p', '--prefix', dest='prefix', help='prefix for output files', default='poolflk')
    # NOTE: --ncpu is accepted for command line compatibility but is not
    # used by the computation below.
    parser.add_argument('--ncpu', dest='ncpu', help='use multiple cpus', default=cpu_count(), type=int)
    parser.add_argument('--kinship', help='Read population kinship from file', metavar='FILE', default=None)
    parser.add_argument('--popnames', required=True,
                        help='Read population names from file (one per line)', metavar='FILE', default=None)
    if len(sys.argv) < 2:
        parser.print_help()
        sys.exit(1)
    opts = parser.parse_args(sys.argv[1:])

    # First word of each non-blank line is a population name.
    with open(opts.popnames) as popnames_file:
        mypopnames = [words[0] for words in map(str.split, popnames_file) if words]

    # The output header has one column per population name while each SNP
    # line has one value per pool: refuse inputs where the two disagree.
    # An empty sync file is let through and yields a header-only output.
    with open(opts.filename) as sync_file:
        first_fields = sync_file.readline().split()
    if first_fields:
        npools = len(first_fields) - 3
        if npools != len(mypopnames):
            sys.exit('Error: %s has %d pools but %s lists %d population names'
                     % (opts.filename, npools, opts.popnames, len(mypopnames)))

    kinship = popgen.popKinship_fromFile(opts.kinship, mypopnames)

    ### FLK calculations
    ff = popgen.FLK_test(kinship)
    flkfunc = partial(pop_contrib_diallelic, w=ff.w, un=ff.un, invF=ff.invF, D=ff.D, Q=ff.Q)
    with open(opts.filename) as sync_file, open(opts.prefix + '.popcontrib', 'w') as fout:
        print('chr', 'pos', 'refa', *mypopnames, file=fout)
        for ligne in sync_file:
            if not ligne.strip():
                # blank lines (e.g. at the end of the file) carry no SNP
                continue
            _dp, _freq, rawfreq = parse_sync_line(ligne)
            _p0hat, _z, _popcontrib, signedpopcontrib = flkfunc(rawfreq)
            print(*ligne.split()[:3], *signedpopcontrib, file=fout)
    # NOTE: a commented-out batched multiprocessing FLK / FLKadapt pipeline
    # used to follow here. It relied on names that are not defined in this
    # file (popidx, covdata) and so could not run as written.

        
# Run the command line tool only when this file is executed as a script,
# not when it is imported.
if __name__=='__main__':
    main()
