#!/usr/bin/env python3
# -*- coding: utf-8 -*-


import argparse
import os
import logging
from importlib import resources
import datetime
import multiprocessing

import pandas as pd

from pcalf.core import log , PcalfDB, PcalfSnake
from pcalf.workflow import annotate

# Module-wide logger: getLogger() without a name returns the root logger.
# Handlers (console + log file) are attached later, inside main().
logger = logging.getLogger()
# Default verbosity; main() lowers it to DEBUG when --debug is passed.
logger.setLevel(logging.INFO)
# NOTE(review): the root logger has no parent to propagate to, so this flag
# presumably has no effect here — confirm before relying on it.
logger.propagate = False

def write_fasta(fastadict,file):
    """Write sequences to *file* in FASTA format and return the path.

    Args:
        fastadict: Mapping of sequence identifier -> sequence. Each entry is
            written as a ``>identifier`` header line followed by the sequence
            on a single line, in the mapping's iteration order.
        file: Destination path. Missing parent directories are created.
            A bare filename (no directory component) is written to the
            current working directory.

    Returns:
        The ``file`` argument, unchanged, so calls can be chained.
    """
    parent = os.path.dirname(file)
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually contains one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(file, 'w') as fh:
        fh.writelines(
            ">{}\n{}\n".format(seqid, sequence)
            for seqid, sequence in fastadict.items()
        )
    return file


def get_args(argv=None):
    """Build the pcalf-annotate command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a zero-argument call has always done.

    Returns:
        argparse.Namespace: parsed options with attributes ``input``, ``db``,
        ``outdir``, ``gtdb``, ``checkm``, ``batch_size``, ``quick``,
        ``force``, ``original``, ``glyx3_msa``, ``gly1_msa``, ``gly2_msa``,
        ``gly3_msa``, ``nter``, ``debug`` and ``snakargs``.

    Raises:
        SystemExit: raised by argparse on invalid or missing arguments, or
            after printing ``--help``.
    """
    parser = argparse.ArgumentParser(
        description="""

    pcalf-annotate
    Annotate genomes using pcalf, GTDB-TK, CheckM and ncbi-datasets CLI.
    """,
        formatter_class=argparse.RawTextHelpFormatter
    )

    # --- Inputs / outputs -------------------------------------------------
    parser.add_argument('-i', '--input', required=True, type=str,
            help = "Yaml file containing path to genome, cds_faa and cds_fna, see README.md for details.")

    parser.add_argument('--db', type=str,
            help = "An sqlite3 database produced by pcalf-workflow. If provided, new datas will be appended to it.")

    # NOTE(review): the option is required, so this default can never be
    # used; kept as-is to avoid changing the CLI contract.
    parser.add_argument('-o', '--outdir', required=True, type=str,
            default = "./pcalf-annotate-workflow",
            help = "Where datas such as checkm, gtdb-tk or pcalf results will be stored.")

    # --- Genome-level annotation (CheckM / GTDB-TK) -----------------------
    parser.add_argument("--gtdb", default=None, type=str,
            help = "path to GTDB database. [None]")

    parser.add_argument("--checkm", default=None, type=str,
            help = "path to checkm datas. [None]")

    parser.add_argument("-z", "--batch-size", default=1500, type=int,
            help = "Batch size for CheckM and GTDBTK. [1500]")

    parser.add_argument('-q', '--quick', default=False, action="store_true",
            help = "If set CheckM and GTDB-TK steps will not be run [False]")

    parser.add_argument('--force', action="store_true",
           help= "Use this flag if you want to resume jobs from a specific instance. [False]" )

    # --- pcalf inputs (HMM sources and N-ter file) -------------------------
    parser.add_argument('--original', action="store_true",
            help = "Use original HMMs and N-ter files for pcalf. [False]")

    parser.add_argument('--glyx3-msa', dest = 'glyx3_msa',
                        default = None,
                        help='Path to GlyX3 msa (default: %(default)s). A weighted HMM will be built from it.' )

    parser.add_argument('--gly1-msa', dest = 'gly1_msa',
                        default = None,
                        help='Path to GlyZip1 msa (default: %(default)s). A weighted HMM will be built from it.' )

    parser.add_argument('--gly2-msa', dest = 'gly2_msa',
                        default = None,
                        help='path to GlyZip2 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--gly3-msa', dest = 'gly3_msa',
                        default = None,
                        help='path to GlyZip3 msa (default: %(default)s). A weighted HMM will be built from it.')

    parser.add_argument('--nter', dest = 'nter',
                        default = None,
                        help='path to N-ter file (default: %(default)s).')

    # --- Misc --------------------------------------------------------------
    parser.add_argument("--debug", action="store_true")

    # By default Snakemake is given as many jobs as there are CPUs.
    parser.add_argument('--snakargs', type=str,
              default="--printshellcmds -j{} --use-conda".format(multiprocessing.cpu_count()),
              help='Snakemake arguments.')

    return parser.parse_args(argv)


def main():
    """Run the pcalf-annotate workflow and store its results in a database.

    Steps, in order:
      1. parse the CLI arguments and attach console and file log handlers;
      2. create the output directory, refusing to reuse an existing one
         unless ``--force`` is set;
      3. instantiate the result DB on ``<outdir>/pcalf.db`` (an existing file
         is removed first when ``--force`` is set) and, when ``--db`` is
         given, export its MSAs and N-ter table to
         ``<outdir>/pcalf_input_datas`` so the workflow can use them;
      4. fill the Snakemake configuration, dump it to
         ``<outdir>/config.yaml`` and run the workflow;
      5. load every expected output table that exists into the DB, record
         today's date for each genome accession in the ``harley`` table, and
         feed the external DB content into the new DB;
      6. warn when an MSA table holds sequences of different lengths.

    Raises:
        SystemExit: with status -1 when the output directory already exists
            without ``--force``, when the external DB schema is not valid,
            or when CheckM/GTDB data directories are missing while
            ``--quick`` is not set.
    """
    # GET ARGS
    args = get_args()
    outdir = os.path.abspath(args.outdir)

    # SETUP LOGGER
    level = logging.DEBUG if args.debug else logging.INFO
    if args.debug:
        # The module-level logger is already at INFO; only DEBUG needs a change.
        logger.setLevel(level)

    console = logging.StreamHandler()
    console.setLevel(level)
    console.setFormatter(log.CustomFormatter())
    logger.addHandler(console)

    # Never write into a pre-existing directory unless explicitly asked to.
    # SystemExit is raised directly: the bare exit() helper is injected by
    # the site module for interactive use and is not always available.
    if os.path.exists(outdir) and not args.force:
        logger.error("{} output directory already exists.".format(outdir))
        raise SystemExit(-1)
    os.makedirs(outdir, exist_ok=True)

    # The log file lives in the output directory, hence it is set up only now.
    logfile = os.path.join(outdir, "pcalf.annotate.log")
    fhandler = logging.FileHandler(logfile)
    fhandler.setLevel(level)
    logger.addHandler(fhandler)
    logger.info("PCALF-WORKFLOW")
    logger.debug("DEBUG")
    logger.info("Output directory :  {}".format(outdir))

    # Results always go to a DB file inside outdir; with --force a DB left
    # there by a previous run is deleted before a new one is instantiated.
    dbfile = os.path.join(outdir, "pcalf.db")
    if os.path.exists(dbfile) and args.force:
        os.remove(dbfile)

    logger.debug("DB file is : {}".format(dbfile))
    db = PcalfDB.PcalfDB(dbfile)

    # MSAs / N-ter file given on the command line (None when not provided).
    glyx3, gly1, gly2, gly3, nter = (
        args.glyx3_msa, args.gly1_msa, args.gly2_msa, args.gly3_msa, args.nter
    )

    if args.db:
        # An external DB was provided: its content replaces the CLI MSAs.
        logger.info("MSAs from {} will be used.".format(args.db))
        externaldb = PcalfDB.PcalfDB(args.db)

        if not externaldb.is_schema_valid():
            logger.error("Current database schema is different from the schema of the database you provide.")
            raise SystemExit(-1)

        # Write MSAs and the N-ter table to the output directory.
        pcalfdatas = os.path.join(outdir, "pcalf_input_datas")
        gly1  = write_fasta(externaldb.to_msa("gly1") , os.path.join(pcalfdatas, "gly1.msa.fasta" ))
        gly2  = write_fasta(externaldb.to_msa("gly2") , os.path.join(pcalfdatas, "gly2.msa.fasta" ))
        gly3  = write_fasta(externaldb.to_msa("gly3") , os.path.join(pcalfdatas, "gly3.msa.fasta" ))
        glyx3 = write_fasta(externaldb.to_msa("glyx3"), os.path.join(pcalfdatas, "glyx3.msa.fasta"))
        nter = os.path.join(pcalfdatas, "nter.tsv")
        nter_df = externaldb.to_df("nterdb")
        nter_df.to_csv(nter, header=False, index=False, sep="\t")

    if args.original:
        # --original wins over both the CLI MSAs and the external DB.
        glyx3, gly1, gly2, gly3, nter = None, None, None, None, None

    # MAKE CONFIG FILE
    # Update Snakefile configuration with CLI values and run workflow.
    annotate_module_path = resources.files(annotate)
    annotate_module = PcalfSnake.Snakemake(annotate_module_path)
    annotate_module.config["global"]["skip_genome_workflow"] = args.quick
    annotate_module.config["global"]["input"] = args.input
    annotate_module.config["global"]["res_dir"] = outdir

    if not args.quick:
        # CheckM and GTDB-TK both need an existing data directory.
        if not args.checkm or not args.gtdb or not os.path.isdir(args.checkm) or not os.path.isdir(args.gtdb):
            logger.error("GTDB and checkm datas should be specified. If you want to skip this part of the analysis use the --quick flag.")
            raise SystemExit(-1)

    annotate_module.config["config-genomes"]["CheckM_data"] = args.checkm
    annotate_module.config["config-genomes"]["GTDB"] = args.gtdb
    annotate_module.config["config-genomes"]["batch_size"] = args.batch_size

    annotate_module.config["config-ccya"]["glyx3_msa"] = glyx3
    annotate_module.config["config-ccya"]["gly1_msa"]  = gly1
    annotate_module.config["config-ccya"]["gly2_msa"]  = gly2
    annotate_module.config["config-ccya"]["gly3_msa"]  = gly3
    annotate_module.config["config-ccya"]["nterdb"]    = nter

    configfile = os.path.join(outdir, "config.yaml")
    annotate_module.dump_config(configfile)
    snakargs = args.snakargs.split()
    # Respect a --configfile supplied by the user through --snakargs.
    if '--configfile' not in snakargs:
        snakargs.extend(["--configfile", configfile])

    annotate_module.run(snakargs)

    ## EXPECTED OUTPUT FILES: (path, DB table name, primary-key column)
    files = [
            (os.path.join(outdir,"pcalf","pcalf.summary.tsv") , "summary" , "sequence_accession" ),
            (os.path.join(outdir,"pcalf","ccyA.summary.tsv") ,  "ccya" , "sequence_id" ),
            (os.path.join(outdir,"pcalf","pcalf.features.tsv") , "features" , "sequence_id" ),
            (os.path.join(outdir,"pcalf","pcalf.hits.tsv") , "hits" , "sequence_id" ),
            (os.path.join(outdir,"ncbi-datas","ncbi_metadatas.tsv") , "genomes" , "Accession" ),
            (os.path.join(outdir,"checkm-res","tables","checkM_statistics_full.tsv") , "checkm" , "Bin Id" ),
            (os.path.join(outdir,"gtdbtk-res","gtdbtk.ar53.bac120.summary.clean.tsv") , "gtdbtk" , "user_genome" ),
            (os.path.join(outdir,"pcalf","MSA","Gly1.msa.tsv") , "gly1"   , "sequence_id" ),
            (os.path.join(outdir,"pcalf","MSA","Gly2.msa.tsv") , "gly2"   , "sequence_id" ),
            (os.path.join(outdir,"pcalf","MSA","Gly3.msa.tsv") , "gly3"   , "sequence_id" ),
            (os.path.join(outdir,"pcalf","MSA","Glyx3.msa.tsv") , "glyx3" , "sequence_id" ),
            (os.path.join(outdir,"pcalf","nter.tsv"), "nterdb", "sequence_id")
        ]

    # CREATE AND FEED DB
    # Outputs that were not produced (e.g. skipped steps) are silently ignored.
    for file, table, pk in files:
        if not os.path.exists(file):
            continue
        logger.info("{} stored as {} in {}".format(
            os.path.basename(file),
            table,
            os.path.basename(dbfile)))
        df = pd.read_csv(file, sep="\t", header=0)
        db.feed_db(df, table, pk)

    # Record, in the "harley" table, today's date for every genome accession
    # currently stored in the "genomes" table.
    gids = db.get_col_values("genomes", "Accession")
    today = datetime.date.today()
    df = pd.DataFrame([[gid, today] for gid in gids], columns=["Accession", "Date"])
    db.feed_db(df, "harley", "Accession")

    # Feed the content of the external DB (if any) into the new DB.
    if args.db:
        tmp = externaldb.to_df("harley")
        db.feed_db(tmp, "harley", "Accession")
        for _, table, pk in files:
            tmp = externaldb.to_df(table)
            db.feed_db(tmp, table, pk)

    # Check that all sequences of each MSA table have the same length.
    # One warning per table is enough: it is emitted as soon as more than
    # one distinct length is seen, instead of once per offending sequence.
    for table in ("gly1", "gly2", "gly3", "glyx3"):
        alignments = db.get_col_values(table, "sequence")
        if len({len(aln) for aln in alignments}) > 1:
            logger.warning("Something went wrong with MSA table {}".format(
                table))



# Run the workflow only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()