#!/usr/bin/env python3
import sys
import argparse
import datetime as dt
import numpy as np
from multiprocessing import cpu_count
from hapflk import utils, hapflk, hapflkadapt
from hapflk import InputOutput as IO



def _snp_coordinates(flkmodel, index):
    """Return ``(name, chromosome, position)`` for the SNP at *index*.

    The name is read from ``flkmodel.sorted_snps[index]`` and resolved
    through ``flkmodel.carte.position``; the position is cast to ``int``
    because that is how it is written to the output tables.
    """
    rs = flkmodel.sorted_snps[index].name
    chrom, _, pos = flkmodel.carte.position(rs)
    return rs, chrom, int(pos)


def _write_flkadapt_table(path, flkmodel, bayes_factors, snp_fits, beta_header):
    """Write the single-SNP result table to *path*, one line per SNP.

    *snp_fits* is iterated in order and paired by index with
    *bayes_factors* and with ``flkmodel.sorted_snps``.  *beta_header* is
    the list of column names written in front of the ``*fit['H1']['beta']``
    values.
    """
    with open(path, 'w') as fout:
        print('rs', 'chr', 'pos', 'pzero.null', 'bf', 'L0', 'L1', 'LRT',
              'pvalue', 'pzero', *beta_header, 'converge', file=fout)
        for i, fit in enumerate(snp_fits):
            rs, chrom, pos = _snp_coordinates(flkmodel, i)
            null_fit = fit['H0']
            alt_fit = fit['H1']
            print(rs, chrom, pos, null_fit['pzero'], bayes_factors[i],
                  null_fit['llik'], alt_fit['llik'], fit['lrt'],
                  fit['pvalue'], alt_fit['pzero'], *alt_fit['beta'],
                  fit['pzsucc'], file=fout)


def _write_hapflkadapt_table(path, flkmodel, hap_fits):
    """Write the haplotype-based result table to *path*, one line per SNP.

    Iterates over *hap_fits* itself (one entry per SNP index), so every
    computed result is written and no out-of-range index can occur.
    """
    with open(path, 'w') as fout:
        print('rs', 'chr', 'pos', 'Zh', 'pvalueh', file=fout)
        for i, fit in enumerate(hap_fits):
            rs, chrom, pos = _snp_coordinates(flkmodel, i)
            print(rs, chrom, pos, fit['Z'], fit['pvalue'], file=fout)


def _analyse_design(counter, flkmodel, design, prefix, beta_header, **model_options):
    """Run every test for one design matrix and write the result files.

    Parameters:
        counter: progress reporter; its ``new(label)`` method is called
            before each step.
        flkmodel: model loaded from the DB file; ``kinship``, ``frq``,
            ``kfrq``, ``nsnp``, ``sorted_snps`` and ``carte`` are read.
        design: design matrix passed to ``hapflkadapt.FLKadapt``.
        prefix: output prefix; ``<prefix>.flkadapt`` is always written and
            ``<prefix>.hapflkadapt`` only when ``flkmodel.kfrq`` is not None.
        beta_header: column names for the beta coefficients.
        **model_options: extra keyword arguments forwarded unchanged to
            ``hapflkadapt.FLKadapt`` (e.g. ``ncpu``, ``stdize``).
    """
    with_haplotypes = flkmodel.kfrq is not None
    hap_fits = None
    with hapflkadapt.FLKadapt(flkmodel.kinship, design, prefix=prefix,
                              **model_options) as flkadapt:
        # Override the LRT correction with a callable that ignores its
        # argument and returns the model's ``df`` (read when called).
        # Defined inside this helper, the closure is tied to this model
        # instance only, not to a variable rebound by a surrounding loop.
        flkadapt.correct_lrt = lambda x: flkadapt.df
        counter.new('Single SNP tests')
        counter.new('BayesFLKadapt')
        bayes_factors = flkadapt.bayesFLKadapt(flkmodel.frq)
        counter.new('FLK adapt')
        snp_fits = flkadapt.FLKadapt(flkmodel.frq)
        if with_haplotypes:
            counter.new('hapFLKadapt')
            hap_fits = [flkadapt.fit_loglik_clust_pz(flkmodel.kfrq[:, :, :, i])
                        for i in range(flkmodel.nsnp)]
    # Result files are written only after the model context has exited.
    _write_flkadapt_table(prefix + '.flkadapt', flkmodel, bayes_factors,
                          snp_fits, beta_header)
    if with_haplotypes:
        _write_hapflkadapt_table(prefix + '.hapflkadapt', flkmodel, hap_fits)


def main():
    """Command-line entry point.

    Parses the command line, loads the FLK model from the ``--flkdb`` file,
    builds the covariate data with ``IO.get_covariates_matrix`` and then,
    for every cofactor and every covariate, fits the model and writes
    ``<prefix>_<name>.flkadapt`` (plus ``<prefix>_<name>.hapflkadapt`` when
    the loaded model has ``kfrq``).

    When called without any command-line argument, prints the help text
    and exits with status 0.
    """
    counter = utils.Stepper()

    parser = argparse.ArgumentParser()
    parser.add_argument('--flkdb', dest='dbfile', help='DB file created by hapflk', default=None)
    parser.add_argument('-o', '--prefix', dest='prefix', help='prefix for output files', default='hapflkadapt_out')
    parser.add_argument('-x', '--ncpu', dest='ncpu', help='Number of threads ( default = number of CPU)', default=cpu_count(), type=int)
    parser.add_argument('--std', dest='std', help='Standardize cofactors', default=False, action="store_true")
    IO.populate_covariate_args(parser)

    if len(sys.argv) < 2:
        parser.print_help()
        sys.exit(0)
    myopts = parser.parse_args(sys.argv[1:])

    counter.new('Loading FLK DB file')
    # NOTE(review): --flkdb is optional, so dbfile may be None here; confirm
    # how HapFLK.from_db_file handles that, or make the option required.
    flkmodel = hapflk.HapFLK.from_db_file(myopts.dbfile)

    counter.new('Getting covariates information')
    covdata = IO.get_covariates_matrix(myopts.covar_file, myopts.cov, myopts.qcov, flkmodel.pops)
    designs = covdata['DesignMatrices']

    for f in covdata['cofactors']:
        counter.new(" Analysing cofactor %s" % f)
        levels = covdata["factor_levels"][f]
        # One beta column per factor level except the first one.
        beta_names = ['beta_' + x for x in levels[1:]]
        _analyse_design(counter, flkmodel, designs[f], myopts.prefix + '_' + f,
                        beta_names, ncpu=myopts.ncpu, stdize=myopts.std)

    for f in covdata['covariables']:
        counter.new(" Analysing covariate %s" % f)
        # Only cofactors forward --std; covariates rely on FLKadapt's own
        # default for stdize.  NOTE(review): confirm this asymmetry is intended.
        _analyse_design(counter, flkmodel, designs[f], myopts.prefix + '_' + f,
                        ['beta'], ncpu=myopts.ncpu)

    counter.new('Done')
    counter.end()
                    
# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
