#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from Bio.SeqRecord import SeqRecord
from Bio.Seq import Seq
from coretracker.coreutils import CoreFile
from coretracker.coreutils.mtgenes import revmtgenes
import argparse
from collections import defaultdict
import sqlite3
from Bio.Data import CodonTable
from Bio import SeqIO, Entrez
import re
import time
import logging
import operator


def isInt(value):
    """Tell whether ``value`` can be converted to an int.

    Returns True when ``int(value)`` succeeds and False otherwise.
    Inputs that ``int`` rejects with a TypeError (``None``, lists, ...)
    are reported as False instead of letting the exception escape.
    """
    try:
        int(value)
    except (TypeError, ValueError):
        return False
    return True


def regexp(expr, item):
    """Case-insensitive REGEXP callback for sqlite.

    sqlite evaluates ``X REGEXP Y`` by calling the registered function as
    ``regexp(Y, X)``: ``expr`` is therefore the pattern and ``item`` the
    column value being tested.

    Returns True when ``expr`` matches anywhere in ``item``. A NULL column
    value (``item is None``) never matches; without this guard the
    resulting TypeError would abort the whole query.
    """
    if item is None:
        return False
    return re.search(expr, item, re.IGNORECASE) is not None


def taxonIDLookup(taxonID, sleep=20, maxCheck=3):
    """Look up a taxon ID in the NCBI taxonomy database through Entrez.

    The request is sent once and retried up to ``maxCheck`` more times,
    waiting ``sleep`` seconds between attempts.

    Returns a tuple ``(scientific_name, lineage, nuclear_code, mito_code)``
    where ``lineage`` is the '>'-joined lineage without its first entry and
    the two codes are genetic code ids (1 when the record gives none).
    Returns None when every attempt failed or the server sent no record.
    """
    resultsDownload = None
    for attempt in range(maxCheck + 1):
        try:
            handleDownload = Entrez.efetch(
                db="taxonomy", id=taxonID, retmode="xml")
            try:
                resultsDownload = Entrez.read(handleDownload)
            finally:
                # release the connection even when parsing fails
                handleDownload.close()
            break
        except Exception:
            # Any network or parsing failure is retried; KeyboardInterrupt
            # and SystemExit are deliberately left alone.
            if attempt == maxCheck:
                logging.warning("!! Server unreachable. Returning nothing.")
                return None
            logging.warning("!! Server error on %s - retrying...", taxonID)
            time.sleep(sleep)

    if not resultsDownload:
        # no attempt was made (negative maxCheck) or the answer was empty
        return None

    record = resultsDownload[0]
    scientificName = record['ScientificName']
    # drop the first lineage entry (cellular organism)
    lineage = ">".join(record['Lineage'].split("; ")[1:])
    mitoCode = int(record.get('MitoGeneticCode', {}).get('MGCId', '1'))
    nucCode = int(record.get('GeneticCode', {}).get('GCId', '1'))
    return (scientificName, lineage, nucCode, mitoCode)


def db_lookup(conn, spec, complete=True):
    """Find a species in the ``genbank`` table of the name database.

    ``spec`` is used as a regular expression (the connection must provide
    a ``REGEXP`` function). The ``organism`` column is searched first and
    the ``source`` column is tried only when that gives nothing. When
    ``complete`` is true only rows flagged ``complete_genome = 1`` are
    considered.

    Returns ``(organism, seqid)`` of the first matching row, or None
    (after logging a warning) when nothing matches.
    """
    genome_filter = " AND complete_genome = 1" if complete else ""
    # Column names come from this fixed tuple, never from user input;
    # the user-supplied pattern is always bound as a parameter.
    for column in ("organism", "source"):
        query = "SELECT organism, seqid FROM genbank WHERE %s REGEXP ?%s" % (
            column, genome_filter)
        # only the first hit is of interest
        row = conn.execute(query, [spec]).fetchone()
        if row:
            return (row[0], row[1])

    logging.warning("Species not found : %s", spec)
    return None


def check_species_list(specielist, namedb, taxid=False, complete=True, sleep=20, maxCheck=3):
    """Resolve every entry of ``specielist`` against the name database.

    ``namedb`` is the path of the sqlite database to query. When ``taxid``
    is true, entries that look like integers are first converted to a
    scientific name with ``taxonIDLookup`` (using ``sleep`` and
    ``maxCheck`` for its retry policy); if that lookup fails the raw entry
    is searched as is. ``complete`` restricts matches to complete genomes.

    Returns the list of ``(organism, seqid)`` tuples found by
    ``db_lookup``; entries without a match are left out.
    """
    curated_spec_list = []
    conn = sqlite3.connect(namedb)
    try:
        conn.create_function("REGEXP", 2, regexp)
        for spec in specielist:
            if taxid and isInt(spec):
                # keywords avoid mixing up the sleep / maxCheck positions
                taxon = taxonIDLookup(spec, sleep=sleep, maxCheck=maxCheck)
                if taxon:
                    # first element of the result is the scientific name
                    spec = taxon[0]
            sciename = db_lookup(conn, spec, complete)
            if sciename:
                curated_spec_list.append(sciename)
    finally:
        # the connection context manager only handles transactions,
        # so the connection is closed explicitly
        conn.close()

    return curated_spec_list


def get_poss_event_diff(gene, prot, features, gcode=1):
    """Guess where ``gene`` starts coding for ``prot`` and translate it.

    ``gene`` is a nucleotide sequence, ``prot`` the template protein (may
    be empty) and ``gcode`` the id of the genetic code table to use.
    ``features`` is accepted for interface compatibility but is not used.

    Returns ``(best_starting_pos, potential_prot)``: the offset among the
    first ten positions of ``gene`` whose leading codons agree best with
    the start of ``prot`` (0 when ``prot`` is empty), and the translation
    of the whole ``gene`` with ``gcode`` (None when translation fails).
    """

    codontable = CodonTable.unambiguous_dna_by_id[gcode]
    ct_dict = codontable.forward_table

    def get_possible_start(gene, potent_prot):
        # Find the best start in the gene according to the template prot.
        # The gene is expected to be a cds; frameshifting is accepted but
        # not in the first 5 aas, so only those are compared.
        n_codons = min(len(gene) // 3, len(potent_prot), 5)
        bestmatch = (0, 0)
        # check start from 10 first position
        for pos in range(10):
            # stop codons and truncated codons are absent from the forward
            # table: .get() makes them count as a mismatch
            matcher = sum(
                ct_dict.get(
                    str(gene[pos + i * 3:pos + (i + 1) * 3]).upper()
                ) == potent_prot[i]
                for i in range(n_codons))
            if matcher > bestmatch[1]:
                # strict comparison keeps the earliest position on ties
                bestmatch = (pos, matcher)
            if matcher == n_codons:
                break

        return bestmatch[0]

    best_starting_pos = 0
    if prot:
        best_starting_pos = get_possible_start(gene, prot)

    potential_prot = None
    try:
        potential_prot = gene.translate(gcode)
    except Exception:
        # best effort: report the failure instead of hiding it
        logging.warning("Could not translate gene with genetic code %s",
                        gcode, exc_info=True)
    return best_starting_pos, potential_prot


def extract_genome(speclist, genelist, records, revgenes, getprot=False, gcode=1):
    """Extract gene and protein records for a list of species.

    Args:
        speclist: iterable of ``(species name, sequence id)`` pairs.
        genelist: lower-case gene names to keep; every gene is kept when
            it is empty.
        records: mapping from sequence id to an annotated sequence record.
        revgenes: mapping from a lower-case gene synonym to its reference
            gene name.
        getprot: accepted for interface compatibility, currently unused.
        gcode: genetic code used to translate a CDS when neither a
            ``translation`` nor a ``transl_table`` qualifier is available.

    Returns:
        ``(gene2spec, prot2spec, spec_code_map)``: two dicts mapping a gene
        name to the list of nucleotide / protein SeqRecords (at most one
        per species and gene) and a dict mapping a species name to the
        last ``transl_table`` value read from its features.
    """
    spec_code_map = {}
    gene2spec = defaultdict(list)
    prot2spec = defaultdict(list)

    def motif_in_seq(motif, slist, product):
        # True when a gene name contains the motif and the product is
        # described as hypothetical
        return any(motif in x.lower() for x in slist) and \
            'hypothetical' in " ".join(product).lower()

    # fetch gene and protein from the database and add them to the list
    for (spec, s_id) in speclist:
        g_found_in_spc = set()
        specaff = spec.replace('.', '').replace('-', '_').replace(' ', '_')
        curr_seq = records[s_id]
        for f in curr_seq.features:
            if f.type.upper() != 'CDS' or 'gene' not in f.qualifiers:
                continue
            # this should filter hypothetical protein that we do not want;
            # a feature without product qualifier is not hypothetical
            if motif_in_seq('orf', f.qualifiers['gene'],
                            f.qualifiers.get('product', [])):
                continue

            sgene = f.qualifiers['gene'][0].lower()
            if sgene in revgenes:
                sgene = revgenes[sgene].lower()
            if genelist and sgene not in genelist:
                continue

            seq = f.extract(curr_seq.seq)
            if len(seq) % 3 != 0:
                try:
                    polyaterm = f.qualifiers['transl_except'][0]
                    pos_range, aa_term = polyaterm.strip(
                        ')').strip('(').split(',')
                    pos_range = pos_range.strip().split(':')[-1]
                    aa_term = aa_term.strip().split(':')[-1]
                    # adding polyA to complete mRNA
                    if 'TERM' in aa_term.upper():
                        # Get the number of A to add: two for a single
                        # position, one for a 'start..end' range
                        n_A = (len(pos_range.split('..')) % 2) + 1
                        # Plain string concatenation does not rely on the
                        # Seq.alphabet attribute, which recent Biopython
                        # releases no longer provide.
                        seq = seq + 'A' * n_A
                        if len(seq) % 3 != 0:
                            raise ValueError(
                                "padding did not restore the reading frame")
                        logging.info("FIXED : partial termination for the following gene: %s - %s | %d %d A added",
                                     sgene, spec, len(seq), n_A)

                except Exception:
                    # missing or unparsable transl_except, or padding that
                    # did not help: report and keep the sequence as it is
                    logging.warning("Possible frame-shifting in the following gene : %s - %s | %s ==> %d (%d)",
                                    sgene, spec, s_id, len(seq), len(seq) % 3)
            elif 'N' in seq:
                logging.warning("Sequence with undefined nucleotide : %s - %s | %d",
                                sgene, spec, len(seq))

            try:
                spec_code_map[spec] = int(f.qualifiers['transl_table'][0])
            except (KeyError, IndexError, ValueError):
                # no usable transl_table on this feature
                pass

            rec = SeqRecord(seq, id=specaff, name=specaff)
            # use the annotated translation when there is one, otherwise
            # translate with the best known genetic code
            translation = f.qualifiers.get("translation")
            if translation and translation[0]:
                protseq = Seq(translation[0])
            else:
                protseq = seq.translate(
                    table=spec_code_map.get(spec, gcode))
            protrec = SeqRecord(protseq, id=specaff, name=specaff)

            if sgene not in g_found_in_spc:
                # this is to ensure that the same gene is not added
                # multiple time
                gene2spec[sgene].append(rec)
                prot2spec[sgene].append(protrec)
            g_found_in_spc.add(sgene)
    return gene2spec, prot2spec, spec_code_map


# Command line entry point. 'build' indexes a genbank file and optionally
# fills the sqlite name database; 'extract' looks species up in that
# database and writes their genes and proteins as corefiles.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extract genome from genbank file into corefile format.")

    sub_parser = parser.add_subparsers(help="Commands", dest='command')
    # Python 3 treats sub-commands as optional unless told otherwise; without
    # this a missing command would fall through to the extract branch.
    sub_parser.required = True
    build_parser = sub_parser.add_parser(
        'build', help="Build database for fast species searching")
    build_parser.add_argument("--namedb", '-db', dest='database',
                              help="Temporary database for fast name searching. Only sqlite database supported")
    build_parser.add_argument("--indexdb", '-ib', dest='indexdb',
                              required=True, help="Index filename")
    build_parser.add_argument("--genbank", "-G", required=True,
                              dest='genbank', help="Your genbank file")

    extract_parser = sub_parser.add_parser(
        'extract', help="Extract genome from your data")
    extract_parser.add_argument("--genbank", "-G", required=True,
                                dest='genbank', help="Your genbank file")
    extract_parser.add_argument("--indexdb", '-ib', dest='indexdb',
                                required=True, help="Index filename")
    extract_parser.add_argument("--namedb", '-db', dest='database', required=True,
                                help="Use build to build a temporary database for fast name searching. Http request will be done if this is not provided")
    extract_parser.add_argument("--taxid", dest='taxid', action='store_true',
                                help="Taxid instead of common name is provided")
    extract_parser.add_argument("--table", dest='gcode', default=1, type=int,
                                help="Genetic code table to use, 1 if not provided")
    extract_parser.add_argument("--sleep", dest='sleep', default=20, type=int,
                                help="Waiting time before next request.")
    extract_parser.add_argument("--maxtry", dest='maxtry', default=3, type=int,
                                help="Maximum number of time a single request should be sent.")
    extract_parser.add_argument("--complete", dest='complete',
                                action='store_true', help="Complete genome only")
    extract_parser.add_argument("--genelist", dest='genelist',
                                help="List of genes. If not provided, all coding sequences in the mapping file will be exported")
    extract_parser.add_argument("--speclist", dest='speclist', required=True,
                                help="A file containing a list of species or a specie name")
    extract_parser.add_argument("--outfile", '-o', dest='output',
                                default='output.core', help="Output file name")
    extract_parser.add_argument("--gsynonym", dest='gsynonym',
                                help="A tsv file containing synonym name for                              each genes. Each line should contain a synonym and the corresponding gene,                                separated by a tabulation")
    extract_parser.add_argument("--email", dest='email',
                                default='mail@example.com', help="Contact email for NCBI's E-utilities")
    args = parser.parse_args()
    if args.command == 'build':
        # in this case, create the databases
        records = SeqIO.index_db(args.indexdb, args.genbank, format="genbank")
        if args.database:
            schema = '''
                create table IF NOT EXISTS genbank (
                    id integer primary key autoincrement not null,
                    accession text not null unique,
                    seqid text not null unique,
                    gi text,
                    complete_genome bit default 0,
                    source text,
                    organism text,
                    lineage text
                );
                '''

            # this will be re-designed if needed
            with sqlite3.connect(args.database) as conn:
                conn.executescript(schema)
                # Iterating over the keys and fetching one record at a time
                # works with every flavour of the indexed dictionary.
                for key in records:
                    rec = records[key]
                    annot = rec.annotations
                    desc = rec.description
                    # prefer the gi annotation, fall back on the accessions
                    gi = annot.get('gi')
                    if gi is None:
                        gi = annot.get('accessions')
                    if isinstance(gi, list):
                        gi = gi[0] if gi else None
                    if gi is None:
                        gi = 'null'

                    conn.execute("insert into genbank (accession, seqid, gi, complete_genome, \
                        source, organism, lineage) values (?, ?, ?, ?, ?, ?, ?)", [rec.name, rec.id, gi,
                                                                                   ('complete genome' in desc), annot['source'], annot['organism'], ">".join(annot['taxonomy'])])

                conn.commit()
                print("%s elements inserted in %s" %
                      (conn.total_changes, args.database))
    else:
        # we are actually trying to extract a genome from a list of spec
        Entrez.email = args.email  # set email for EUtilities
        records = SeqIO.index_db(args.indexdb, args.genbank, format="genbank")
        speclist = []
        try:
            with open(args.speclist) as spec_file:
                for line in spec_file:
                    line = line.strip()
                    if line and not line.startswith('#'):
                        speclist.append(line)
        except IOError:
            # not a readable file: the argument is itself a species name
            speclist.append(args.speclist)

        speclist = check_species_list(set(speclist), args.database,
                                      args.taxid, args.complete, args.sleep, args.maxtry)

        synonym = {}
        if args.gsynonym:
            with open(args.gsynonym) as SYN:
                for line in SYN:
                    line = line.strip()
                    if not line:
                        # tolerate blank lines
                        continue
                    syno, gname = line.split()
                    # gene names are lower-cased before the synonym lookup,
                    # so the keys have to be lower case as well
                    synonym[syno.lower()] = gname
        else:
            synonym = revmtgenes

        genelist = []
        if args.genelist:
            with open(args.genelist) as Glist:
                for line in Glist:
                    line = line.strip()
                    # an empty entry would turn the gene filter on while
                    # matching nothing
                    if line and not line.startswith('#'):
                        genelist.append(line.lower())
        gene2spec, prot2spec, spec_code = extract_genome(
            speclist, genelist, records, synonym, gcode=args.gcode)
        CoreFile.write_corefile(gene2spec, args.output)
        CoreFile.write_corefile(prot2spec, args.output + "_prot")

        print("\n---------------------\n+++ Genetic code : \n")
        for spec, code in spec_code.items():
            print("%s ==> %s" % (spec, str(code)))
        print('\n---------------------\n+++ Number of genomes with genes : \n')
        for g, spec in gene2spec.items():
            print("%s ==> %d specs" % (g, len(spec)))
