#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import argparse
import gzip
import os
import sys
import shutil
import logging
import multiprocessing
from importlib import resources
import pandas as pd


import pcalf.datas as datas
from pcalf.core import search,bioseq,biohmm,log


# Filesystem location (as a plain string) of the data files bundled in the
# `pcalf.datas` package; used below to build the default MSA / N-ter DB paths.
DATASDIR = os.fspath(resources.files(datas))

def get_args(argv=None):
    """Build the pcalf command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse falls back to ``sys.argv[1:]``, which is what
            the previous zero-argument signature always did.

    Returns:
        argparse.Namespace: the parsed options. Attribute names follow the
        ``dest`` given for each option below.
    """
    parser = argparse.ArgumentParser(
        description="""                        
    pcalf 
    Search calcyanin protein in one or multiple files containing CDS.
    """,
        formatter_class=argparse.RawTextHelpFormatter
    )

    # --- Input / output -----------------------------------------------------
    # The input is mandatory: main() inspects it unconditionally, so a missing
    # value used to surface as an AttributeError instead of a usage message.
    parser.add_argument('-i', '--input', dest='input',
                        type=str, required=True,
                        help="Either a fasta file [gzip supported] or a tabular file with one file \"designation\tfile path\" per line.")

    parser.add_argument('-o', dest='res_dir', type=str, required=True,
                        help='Output directory. pcalf will ouput several files including updated HMMs and MSAs, a feature table and a summary.')

    # --- Iterative search ---------------------------------------------------
    # type=int so that a value given on the command line has the same type as
    # the default (without it argparse hands back a str).
    parser.add_argument('--max-iteration', type=int, default=4,
                        help="Max iteration to be done if iterative-search enable. (default: %(default)s) ")

    parser.add_argument('--iterative-update', action="store_true",
                        help="If set, HMM profiles will be updated by aligning one sequence by one sequence. ")

    # --- GlyX3 --------------------------------------------------------------
    parser.add_argument('--glyx3-msa', dest='glyx3_msa',
                        default=DATASDIR + "/GlyX3.msa.fa",
                        help='Path to GlyX3 msa (default: %(default)s). A weighted HMM will be built from it.')

    # NOTE(review): the two help texts below mention "None", but both options
    # default to 10000 and are typed as int.
    parser.add_argument('-Z', dest='Z', type=int, default=10000,
                        help='The effective number of comparisons done, for E-value calculation.\
                            Leave as None to auto-detect by counting the number of sequences queried.\
                            (default: %(default)s).')

    parser.add_argument('--domZ', dest='domZ', type=int, default=10000,
                        help='The number of significant sequences found, for domain E-value calculation.\
                            Leave as None to auto-detect by counting the number of sequences reported. \
                            (default: %(default)s).')

    parser.add_argument('--glyx3-coverage', dest='glyx3_coverage_threshold',
                        type=float, default=None,
                        help="Minimal coverage (default: auto).")

    parser.add_argument('--glyx3-evalue', dest='glyx3_evalue_threshold',
                        type=float, default=None,
                        help="Glyx3 E-value when considering Glyx3 Hits (default: auto).")

    # --- GlyZip (Gly1 / Gly2 / Gly3) ----------------------------------------
    parser.add_argument('--gly1-msa', dest='gly1_msa',
                        default=DATASDIR + "/Gly1.msa.fa",
                        help='Path to GlyZip1 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--gly2-msa', dest='gly2_msa',
                        default=DATASDIR + "/Gly2.msa.fa",
                        help='path to GlyZip2 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--gly3-msa', dest='gly3_msa',
                        default=DATASDIR + "/Gly3.msa.fa",
                        help='path to GlyZip3 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--glyzip-E-value', dest='glyzip_e_evalue',
                        type=float, default=None,
                        help="Glyzip E-evalue when considering Gly(1|2|3) Hits (default: auto).")

    parser.add_argument('--glyzip-coverage', dest='glyzip_coverage',
                        type=float, default=None,
                        help="Glyzip coverage threshold when considering Gly(1|2|3) Hits (default: None).")

    # --- N-ter --------------------------------------------------------------
    parser.add_argument('--nterdb', dest='nterdb',
                        default=DATASDIR + "/nterdb.ref.tsv",
                        help='Path to nterdb tabular file (default: %(default)s). The file must have three column: N-ter types, N-ter accessions and N-ter amino acid sequences.')

    parser.add_argument('--nter-coverage', dest='nter_coverage',
                        type=float, default=0.80,
                        help="Nter minimal coverage (default: %(default)s).")

    parser.add_argument('--nter-evalue', dest='nter_evalue',
                        type=float, default=1e-07,
                        help="Nter maximal E-value threshold (default: %(default)s).")

    # --- Logging / misc -----------------------------------------------------
    parser.add_argument('--log', default=None, type=str)

    parser.add_argument('-q', '--quiet', action="store_true", help="Silent stdout logging")

    parser.add_argument('--debug', action="store_true")

    parser.add_argument('--force', action="store_true", help="overwrite output directory")

    parser.add_argument('--threads', type=int, default=multiprocessing.cpu_count(),
                        help="(default: %(default)s)")

    return parser.parse_args(argv)


def _setup_logging(args):
    """Configure the root logger from the parsed CLI options and return it."""
    logger = logging.getLogger()
    level = logging.DEBUG if args.debug else logging.INFO
    logger.setLevel(level)

    if not args.quiet:
        console = logging.StreamHandler()
        console.setFormatter(log.CustomFormatter())
        console.setLevel(level)
        logger.addHandler(console)

    if args.log:
        # Create the parent directory of the log file if needed.
        os.makedirs(os.path.dirname(os.path.abspath(args.log)), exist_ok=True)
        fhandler = logging.FileHandler(args.log)
        fhandler.setLevel(level)
        fhandler.setFormatter(log.CustomFormatter())
        logger.addHandler(fhandler)

    return logger


def _collect_inputs(input_path):
    """Resolve the -i argument into parallel lists of FASTA paths and names.

    A path ending in ".gz" must be a gzipped FASTA. Any other path is either a
    FASTA file (a line starting with ">") or a table with one
    whitespace-separated "designation  file-path" pair per line.

    Args:
        input_path: path given on the command line.

    Returns:
        tuple[list[str], list[str]]: (fastas, names), index-aligned.

    Raises:
        ValueError: the content is neither a FASTA nor a valid table, or the
            file holds no usable line.
        OSError: the file cannot be opened or read.
    """
    if input_path.endswith(".gz"):
        with gzip.open(input_path, 'rt') as stream:
            # Only the first non-blank line is needed; iterate lazily instead
            # of decompressing the whole file into memory.
            first = next((line for line in stream if line.strip()), None)
        if first is None or not first.startswith(">"):
            raise ValueError("{} is not a fasta file.".format(input_path))
        return [input_path], [os.path.basename(input_path)]

    fastas = []
    names = []
    with open(input_path, 'r') as stream:
        for lineno, line in enumerate(stream, start=1):
            if not line.strip():
                # Blank lines (e.g. a trailing newline) carry no information.
                continue
            if line.startswith(">"):
                # FASTA header: the input itself is the sequence file.
                fastas.append(input_path)
                names.append(os.path.basename(input_path))
                break
            fields = line.split()
            if len(fields) != 2:
                raise ValueError(
                    "{}: line {} is neither a fasta header nor a "
                    "'designation<TAB>file path' pair.".format(input_path, lineno))
            name, path = fields
            fastas.append(path)
            names.append(name)

    if not fastas:
        raise ValueError("{} does not contain any input.".format(input_path))
    return fastas, names


def main():
    """Command-line entry point of pcalf.

    Parses the arguments, configures logging, checks that ``blastp`` is on the
    PATH, resolves the input file(s), prepares the output directory, builds
    the four HMMs from their MSAs, and calls ``search.run_pcalf``. Finally it
    logs how many sequences each returned HMM holds compared with the start.

    Exits with status -1 (through ``sys.exit``) when blastp is missing, when
    the input is missing/unreadable/invalid, or when the output directory
    already exists and ``--force`` was not given.
    """
    args = get_args()
    logger = _setup_logging(args)

    logger.info("PCALF")
    logger.debug("DEBUG")

    blastp = shutil.which("blastp")
    if not blastp:
        logger.error("blast not found, please, considere installing it using conda install -c bioconda blast.")
        # sys.exit rather than the `exit` helper, which is injected by the
        # `site` module and is not guaranteed to exist.
        sys.exit(-1)
    logger.debug("blastp found : {}".format(blastp))

    # Validate the input before touching the output directory, so that a bad
    # input never costs the user an existing result directory under --force.
    if args.input is None:
        logger.error("No input file given, please use -i/--input.")
        sys.exit(-1)
    try:
        fastas, names = _collect_inputs(args.input)
    except (OSError, ValueError) as err:
        logger.error("%s", err)
        sys.exit(-1)

    res_dir = os.path.abspath(args.res_dir)
    if os.path.exists(res_dir):
        if not args.force:
            logger.error("Output directory already exist , Bye !")
            sys.exit(-1)
        # NOTE(review): if --log points inside res_dir, the log file opened by
        # the file handler above is removed here as well.
        shutil.rmtree(res_dir)
        logger.debug("{} directory have been removed because of the --force flag.".format(res_dir))

    logger.info("Init HMMs from MSAs.")
    glyx3 = biohmm.Hmm("Glyx3", args.glyx3_msa)
    gly1 = biohmm.Hmm("Gly1", args.gly1_msa)
    gly2 = biohmm.Hmm("Gly2", args.gly2_msa)
    gly3 = biohmm.Hmm("Gly3", args.gly3_msa)

    logger.info("Increase Glycine weight.")
    glyx3.hmm = search.increase_glycine_weight(glyx3, 0.2)
    gly1.hmm = search.increase_glycine_weight(gly1, 0.2)
    gly2.hmm = search.increase_glycine_weight(gly2, 0.2)
    gly3.hmm = search.increase_glycine_weight(gly3, 0.2)

    logger.info("Init N-Ter DB.")
    nterdb = search.parse_nterdb(args.nterdb)

    # Number of sequences in each HMM before the search, kept to report the
    # difference afterwards.
    _glyx3_nseq = glyx3.hmm.nseq
    _gly1_nseq = gly1.hmm.nseq
    _gly2_nseq = gly2.hmm.nseq
    _gly3_nseq = gly3.hmm.nseq

    logger.info("Start search.")
    calseq, u_glyx3, u_gly1, u_gly2, u_gly3, u_nter = search.run_pcalf(
        fastas,
        names,
        glyx3,
        gly1,
        gly2,
        gly3,
        nterdb,
        Z=args.Z,
        domZ=args.domZ,
        glyx3_evalue_threshold=args.glyx3_evalue_threshold,
        glyx3_coverage_threshold=args.glyx3_coverage_threshold,
        glyzip_evalue_threshold=args.glyzip_e_evalue,
        glyzip_coverage_threshold=args.glyzip_coverage,
        nter_coverage_threshold=args.nter_coverage,
        nter_evalue_threshold=args.nter_evalue,
        is_update_iterative=args.iterative_update,
        max_iteration=args.max_iteration,
        res_dir=res_dir
    )

    logger.info("The search is over.")

    logger.info("# N_seqs within Glyx3 HMM : {} [+{}]".format(
        u_glyx3.hmm.nseq, u_glyx3.hmm.nseq - _glyx3_nseq))
    logger.info("# N_seqs within Gly1 HMM : {} [+{}]".format(
        u_gly1.hmm.nseq, u_gly1.hmm.nseq - _gly1_nseq))
    logger.info("# N_seqs within Gly2 HMM : {} [+{}]".format(
        u_gly2.hmm.nseq, u_gly2.hmm.nseq - _gly2_nseq))
    logger.info("# N_seqs within Gly3 HMM : {} [+{}]".format(
        u_gly3.hmm.nseq, u_gly3.hmm.nseq - _gly3_nseq))
    
    ##

   

if __name__ == "__main__":
    # Run the CLI only when executed as a script; whatever main() returns is
    # handed to sys.exit() as the process exit status.
    exit_status = main()
    sys.exit(exit_status)
