#!/usr/bin/env python

import numpy as np
import scipy.stats as st
from pickle import load,loads,dump,dumps
import click
import blosc
from statsmodels.stats.multitest import multipletests 
import re

import os
from pathlib import Path
# Directory holding PROST's shared data files (cache.pkl, the SwissProt
# database files). The PROSTDIR environment variable overrides the default.
prostdir = os.environ['PROSTDIR'] if 'PROSTDIR' in os.environ else str(Path.home())+'/.config/prost'

from itertools import groupby
def fasta_iter(fastafile):
    '''Yield (header, sequence) pairs from a FASTA file.

    The header is the '>' line without the '>' and surrounding whitespace.
    The sequence is the concatenation of the stripped lines that follow it.
    A header with no sequence lines, whether followed by another header or by
    the end of the file, yields an empty sequence. Lines before the first
    header are ignored. The file is closed when iteration finishes.
    '''
    with open(fastafile) as fh:
        # groupby alternates between runs of header lines and runs of sequence lines.
        groups = groupby(fh, lambda line: line[0] == ">")
        for is_header, lines in groups:
            if not is_header:
                # Sequence lines before the first header have no owner.
                continue
            headers = [h[1:].strip() for h in lines]
            # Back-to-back headers: every header except the last has no sequence.
            for h in headers[:-1]:
                yield h, ""
            seq_group = next(groups, None)
            seq = "" if seq_group is None else "".join(s.strip() for s in seq_group[1])
            yield headers[-1], seq


def check_seq(seq):
    '''Check that seq only contains amino-acid letters PROST accepts.

    The check is case-insensitive. The 20 standard residues and the ambiguous
    codes X, B, U, Z and O are accepted.

    Returns (True, '') when every character is accepted. Otherwise returns
    (False, c), where c is the first rejected character (upper-cased).
    '''
    allowed = set('ACDEFGHIKLMNPQRSTVWY') | set('XBUZO')
    offending = next((c for c in seq.upper() if c not in allowed), None)
    if offending is None:
        return True, ''
    return False, offending
    
# UniProt-style FASTA header: "..|ACCESSION|ENTRY_NAME description OS=.. OX=.. [GN=..] PE=.."
# Raw strings avoid invalid-escape warnings; the patterns are compiled once
# because parseName is called for every reported hit.
_UNIPROT_GN_RE = re.compile(r"\|([\w]+)\|(\w+_\w+) (.*) OS=(.*) OX=(\d+) GN=(.*) PE=")
_UNIPROT_RE = re.compile(r"\|([\w]+)\|(\w+_\w+) (.*) OS=(.*) OX=(\d+) PE=")

def parseName(name):
    '''Split a UniProt-style FASTA header into its fields.

    Returns a 6-tuple (UniprotID, Name, Type, Organism, OrganismID, Gene).
    Gene is '' when the header has no GN= field. A header that matches
    neither pattern is returned as (name, '', '', '', '', ''). A tuple input
    is assumed to be already parsed and is returned unchanged.
    '''
    if isinstance(name, tuple):
        return name
    res = _UNIPROT_GN_RE.search(name)
    if res is not None:
        return res.groups()
    res = _UNIPROT_RE.search(name)
    if res is not None:
        return res.groups() + ('',)
    return (name,'','','','','')

def annotate(ind,evals,go,goFrq,goDesc):
    '''Find GO terms that are over-represented among a set of database hits.

    ind    -- indices into ``go`` of the hit proteins, aligned with ``evals``.
    evals  -- E-value of each hit.
    go     -- per-protein GO term lists, indexable by an index array
              (the object array built by mkgo).
    goFrq  -- GO term -> number of occurrences in the whole database. The key
              'count' holds the total number of term occurrences.
    goDesc -- GO term -> description text.

    Returns a list of [term, description, confidence, first_hit_index,
    first_hit_evalue, n_hits_with_term] entries, sorted by decreasing
    confidence. Returns an empty list when no hit carries a GO term.
    '''
    spTotalCnt = goFrq['count']

    # Count GO term occurrences among the hits. Empty terms ('' for proteins
    # without annotation) are included in totalCnt but are not tested.
    totalCnt = 0
    goDict = {}
    for terms in go[ind]:
        for term in terms:
            goDict[term] = goDict.get(term, 0) + 1
            totalCnt += 1
    goDict.pop('', None)
    if not goDict:
        return []

    # Chi-square test of each term's hit count against its database frequency.
    plist = []
    for g,cnt in goDict.items():
        contTable = [[cnt,goFrq[g]],[totalCnt,spTotalCnt]]
        _,p,_,_ = st.chi2_contingency(contTable)
        plist.append([g,p])

    # Bonferroni correction over all tested terms.
    corrp = multipletests([s[1] for s in plist], method="bonferroni")[1]

    significant = []
    for (term,_),p in zip(plist,corrp):
        # Only terms with corrected p < 0.001 go on to the E-value check.
        if p < 0.001:
            # E-values (in hit order) of the hits that carry this term.
            prot_evals = []
            prot_inds = []
            for pind,prot in enumerate(ind):
                if term in go[prot]:
                    prot_evals.append(evals[pind])
                    prot_inds.append(prot)
            # Convert E-values to p-values with p = 1 - exp(-E). Using -expm1(-E)
            # keeps tiny E-values from rounding to exactly 0.
            prot_pvals = -np.expm1(-np.array(prot_evals))
            p2 = st.combine_pvalues(prot_pvals,method='stouffer')[1]
            # Correct the combined p-value for the number of hits, then scale it
            # by 10 so that 0.05 maps to 0.5. Confidence is 1 minus that value,
            # so 0.05 gives 0.5 (and the result can be negative for many hits).
            conf = 1-p2*len(prot_evals)*10

            # Keep the term only if the combined p-value is below 0.05.
            if p2 < 0.05:
                significant.append([term,goDesc[term],conf,prot_inds[0],prot_evals[0],len(prot_evals)])

    # Sort by decreasing confidence.
    significant.sort(reverse=True,key=lambda x: x[2])
    return significant
    
@click.command()
@click.argument('gocsv', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('goobo', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('prdb', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('out', type=click.Path(exists=False,file_okay=True,dir_okay=False))
def mkgo(gocsv, goobo, prdb, out):
    '''mkgo command creates a GO database suitable for PROST.
mkgo command gets a GO annotation file in csv format, GO descriptions in obo format,
and a PROST database to create a GO annotations file.

It will create an output file that contains [go,freq,desc] lists.
'''

    # Each csv line is read as "<protein id>,<term>;<term>;...".
    print("Read the go csv file",gocsv)
    go = {}
    with open(gocsv,'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # tolerate blank lines, e.g. a trailing empty line
            uid,golist = line.split(',')
            # sorted() makes the term order, and so the pickled output, reproducible
            go[uid] = sorted(set(golist.replace(' ','').split(';')))

    print("Read the PROST database",prdb)
    with open(prdb,'rb') as f:
        qnames,qdb = loads(blosc.decompress(f.read()))

    print("Gather GO annotations for proteins in the database")
    # Database names are expected in "db|ID|..." form; the second field is the csv key.
    godb = np.empty(len(qnames),dtype=object)
    for i,name in enumerate(qnames):
        godb[i] = go[name.split('|')[1]]

    print("GO Frequencies: count the occurances of terms")
    frq = {}
    totalCnt = 0
    for protein_terms in godb:
        for term in protein_terms:
            frq[term] = frq.get(term,0)+1
            totalCnt += 1

    uniqueTermCnt = len(frq)
    print("Total term count:",totalCnt, "Unique term count:",uniqueTermCnt)

    frq['count'] = totalCnt
    frq['uniqueTerm'] = uniqueTermCnt

    # Map each [Term] stanza's "id:" to its "name:". The whole value after the
    # first ': ' is kept, so names that contain ': ' are not truncated.
    print("Prepare GO descriptions from",goobo)
    terms = {}
    in_term = False
    term_id = None
    with open(goobo,'r') as f:
        for line in f:
            line = line.strip()
            if line.startswith('['):
                # A new stanza starts; only [Term] stanzas provide descriptions.
                in_term = line.startswith('[Term]')
                term_id = None
                continue
            if not in_term:
                continue
            key,_,value = line.partition(': ')
            if key == 'id':
                term_id = value
            elif key == 'name' and term_id is not None:
                if term_id in terms:
                    print('error',term_id,line,'exists in the dictionary',terms[term_id])
                else:
                    terms[term_id] = value

    print("Look at swissprot annotations and if they dont have description then add empty one")
    cnt = 0
    for g in frq:
        if g.startswith('GO:') and g not in terms:
            cnt +=1
            terms[g] = ''

    print('Empty description count',cnt)

    print(len(godb),len(qdb),len(qnames),len(frq),len(terms))
    with open(out,'wb') as f:
        dump([godb,frq,terms],f)


@click.command()
@click.option('-t', '--test', is_flag=True, default=False, help='Test random 1000 embeddings')
@click.argument('fasta', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('prdb', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('out', type=click.Path(exists=False,file_okay=True,dir_okay=False))
def mkcache(test, fasta, prdb, out):
    '''mkcache command gets a fasta file and a PROST database to create a cache file.
Cache files are dictionaries that map amino acid sequences to PROST embeddings.
This command should be run on unparsed PROST databases (no parseUniprotNames)'''
    from random import sample
    from pyprost import quantSeq
    with open(prdb,'rb') as f:
        qnames,qdb = loads(blosc.decompress(f.read()))
    # fasta headers are matched against the raw (unparsed) database names
    prdbdict = dict(zip(qnames,qdb))

    cache = {}
    seq = {}
    for name,s in fasta_iter(fasta):
        if name not in prdbdict:
            print('could not found',name,len(s))
            continue
        cache[s] = prdbdict[name]
        seq[name] = s

    with open(out,'wb') as f:
        dump(cache,f)

    # Cache keys are exactly the distinct matched sequences, so the unique
    # sequence count equals the cache size.
    print('PROST db size:',len(qdb),'cache size:',len(cache),'unique seq size',len(cache))
    if test:
        # sample() raises if asked for more items than exist; cap at the db size.
        ntest = min(1000,len(qdb))
        print(f"Testing random {ntest} entries by re quantizing them and checking if the cached version is the same")
        for i in sample(range(len(qdb)), ntest):
            name = qnames[i]
            if name not in seq:
                print('Not in fasta, skipped',i,name)
                continue
            s = seq[name]
            cached = cache[s]
            q = quantSeq(s)
            if not np.array_equal(cached, qdb[i]):
                print('DB Cache missmatch!',i,name,s)
            # Compare in float64 so subtracting small-integer embeddings cannot wrap around.
            diff = np.sum(np.abs(np.asarray(cached,dtype=np.float64)-np.asarray(q,dtype=np.float64)))
            if diff > 2:
                print('Cache Quant missmatch!',i,diff,name,s)


@click.command()
@click.argument('prdb', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('out', type=click.Path(exists=False,file_okay=True,dir_okay=False))
def parseUniprotNames(prdb, out):
    '''PROST python package parseUniprotNames command.
parseUniprotNames command gets a PROST database file and parses the names by:
[UniprotID, Name, Type, Organism, OrganismID, Gene]
and saves the updated PROST database to out argument.'''
    with open(prdb,'rb') as f:
        qnames,qdb = loads(blosc.decompress(f.read()))
    # Object array so that each element holds one parsed name tuple.
    names = np.empty(len(qnames),dtype=object)
    for i,name in enumerate(qnames):
        names[i] = parseName(name)

    # Show the first and last parsed names as a sanity check (none for an empty db).
    if len(names):
        print(len(names),names[0],names[-1])
    else:
        print(len(names))

    with open(out,'wb') as f:
        f.write(blosc.compress(dumps([names,qdb])))


@click.command()
@click.option('-n', '--no-cache', is_flag=True, default=False, help='Disable embedding caching')
@click.argument('fasta', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('out', type=click.Path(exists=False,file_okay=True,dir_okay=False))
def makedb(no_cache, fasta, out):
    '''makedb command creates PROST databases from FASTA files.
makedb command gets a fasta file and creates a PROST database that can be used as query or target database in a search.'''
    from pyprost import quantSeq

    # Sequence -> embedding cache shared between runs, stored in prostdir.
    cachefile = prostdir+'/cache.pkl'
    cache = {}
    cacheDirty = False
    if not no_cache and os.path.exists(cachefile):
        with open(cachefile,'rb') as f:
            cache = load(f)

    quant = []
    namesd = {}  # name -> row index in quant; insertion order is the db order

    for name,s in fasta_iter(fasta):
        l = len(s)
        if l < 5:
            print(name,'discarded, length:',l)
            continue

        status,offchar = check_seq(s)
        if not status:
            print(name,'contains unknown aa',offchar)
            continue

        if name in namesd:
            # Only the first entry with a given name is kept.
            print(name,'is already exits!')
            continue

        namesd[name] = len(quant)

        if s in cache:
            quant.append(cache[s])
        else:
            print(name,'not found in cache. Quantize it.')
            qseq = quantSeq(s)
            quant.append(qseq)
            cache[s] = qseq
            cacheDirty = True

        assert np.shape(quant[-1])[0] == 475

    names = list(namesd.keys())

    assert len(names) == np.shape(quant)[0]
    print('Total number of sequences embedded in the db:',len(names))

    with open(out,'wb') as f:
        f.write(blosc.compress(dumps([np.array(names),np.array(quant)])))

    if not no_cache and cacheDirty:
        # prostdir may not exist yet (e.g. on a first run). Create it so saving
        # the cache does not fail after the database has already been written.
        os.makedirs(prostdir,exist_ok=True)
        with open(cachefile,'wb') as f:
            dump(cache,f)

def _search(thr, gothr, querydb, targetdb, godb, out):
    '''Search every query embedding against a target database and write a TSV report.

    thr      -- E-value threshold for reporting homologs.
    gothr    -- E-value threshold for the hits used in GO annotation (only used with godb).
    querydb, targetdb -- PROST database files (blosc-compressed pickles of [names, embeddings]).
    godb     -- GO annotation pickle as written by mkgo, or None to skip annotation.
    out      -- output path. Each line is one tab-separated GO annotation or homolog.
    '''
    if godb is not None:
        with open(godb,'rb') as f:
            go,goFrq,goDesc = load(f)
    with open(querydb,'rb') as f:
        qnames,qdb = loads(blosc.decompress(f.read()))
    with open(targetdb,'rb') as f:
        tnames,tdb = loads(blosc.decompress(f.read()))
    ldb = len(tdb)
    output = []

    # Scratch buffer for |target - query|, reused for every query and sized from the target db.
    mem = np.zeros(np.shape(tdb),dtype='int8')
    for raw_qname,q in zip(qnames,qdb):
        qname = parseName(raw_qname)[0]
        print(f'Searching for {qname}')
        # L1 distance between the query and every target embedding.
        np.subtract(tdb,q,out=mem)
        np.absolute(mem,out=mem)
        dbdiff = mem.sum(axis=1)
        # Robust z-score from the median and the MAD. 1.4826 is the usual
        # factor that turns a MAD into a normal standard deviation.
        m = np.median(dbdiff)
        s = st.median_abs_deviation(dbdiff)*1.4826
        # NOTE(review): s is 0 when most distances are identical (e.g. very
        # small target dbs). Then zscore is inf/nan and those rows give no hits.
        zscore = (dbdiff-m)/s
        # E-value: lower-tail normal probability times the database size.
        e = st.norm.cdf(zscore)*ldb

        # Homolog hits, sorted by increasing E-value.
        res = np.where(e < thr)[0]
        res = res[np.argsort(e[res])]
        halfdiff = dbdiff[res]/2
        evals = e[res]
        names = tnames[res]

        if godb is not None:
            res2 = np.where(e < gothr)[0]
            res2 = res2[np.argsort(e[res2])]
            for a in annotate(res2,e[res2],go,goFrq,goDesc):
                output.append(f'{qname}\t{a[0]}\t{a[1]}\t{a[2]:.3f}\t{parseName(tnames[a[3]])[0]}\t{a[5]}\t{a[4]:.2e}')

        for n,diff,ev in zip(names,halfdiff,evals):
            n = parseName(n)
            output.append(f'{qname}\t{n[0]}\t{n[1]}\t{n[2]}\t{n[3]}\t{diff}\t{ev:.2e}')

    with open(out,'w') as f:
        f.writelines(o+'\n' for o in output)

@click.command()
@click.option('--thr', default=0.05, help='E-value threshold for homolog detection')
@click.argument('querydb', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('targetdb', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('out', type=click.Path(exists=False,file_okay=True,dir_okay=False))
def search(thr, querydb, targetdb, out):
    '''Search a query database in target database.
This command searches a query database against a target database.
Both databases should be created using makedb command.
Databases can contain one or more sequences.
An e-value threshold can be specified with --thr flag. The default e-value threshold is 0.05'''
    # Plain homology search: no GO database is given, so no GO threshold is needed.
    _search(thr=thr, gothr=None, querydb=querydb, targetdb=targetdb, godb=None, out=out)

@click.command()
@click.option('--thr', default=0.05, help='E-value threshold for homolog detection')
@click.option('--gothr', default=0.05, help='E-value threshold for GO annotation')
@click.argument('querydb', type=click.Path(exists=True,file_okay=True,dir_okay=False))
@click.argument('out', type=click.Path(exists=False,file_okay=True,dir_okay=False))
def searchsp(thr,gothr, querydb, out):
    '''Search query database in SwissProt February 2023 database.
This command searches a query database against a SwissProt February 2023 database.
Query database should be created using makedb command.
It can contain one or more sequences.
An e-value threshold can be specified with --thr flag. The default e-value threshold is 0.05.
A separate GO annotation threshold can be specified with --gothr flag. The default is 0.05.
The SwissProt database files are read from $PROSTDIR (default ~/.config/prost).'''
    targetdb = prostdir+'/sp.02.23.parsed.prdb'
    godb = prostdir+'/sp.02.23.go.pkl'
    # If the SwissProt files have not been installed, report a readable CLI
    # error instead of a FileNotFoundError traceback from deep inside the search.
    for path in (targetdb,godb):
        if not os.path.isfile(path):
            raise click.ClickException(f'{path} not found. Place the SwissProt database files in {prostdir} or set PROSTDIR.')
    _search(thr,gothr,querydb,targetdb,godb,out)

@click.group()
def cli():
    '''PROST python package v0.2.7
Please specify a command.
makedb: creates a PROST database from given fasta file. The fasta file usually contains more than one entry.
search: searches a query database against a target database. Query database can contain one or more sequences embedded using makedb command.
searchsp: searches a query database against SwissProt February 2023 database. Query database can contain one or more sequences embedded using makedb command.
mkgo: creates a GO annotation file (in the format searchsp reads) from a GO csv file, a GO obo file and a PROST database.
mkcache: creates a sequence-to-embedding cache file from a fasta file and a PROST database.
parseUniprotNames: parses UniProt-style names in a PROST database and saves the updated database.'''
    # Group entry point only; the subcommands do the work.
    pass

# Register every subcommand with the top-level group, in help-listing order.
for _command in (makedb, search, searchsp, mkgo, mkcache, parseUniprotNames):
    cli.add_command(_command)
del _command

if __name__ == '__main__':
    cli()
