#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse
import itertools
import json
import os
import re
from collections import defaultdict

import matplotlib.cm as mplcm
import matplotlib.colors as mc
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from Bio.Data import CodonTable
from sklearn.cluster import DBSCAN, MeanShift, estimate_bandwidth
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler

# Machine epsilon for Python floats.
FEPS = np.finfo(float).eps

# seaborn is optional: its styling is applied only when the import works.
try:
    import seaborn as sns
    sns.set(context="paper", palette="deep")
    sns.set(style="ticks")
except ImportError:
    # quiet fallback to matplotlib
    SEABORN = False
else:
    SEABORN = True


# networkx is optional too; without it labels are placed with fixed offsets.
try:
    import networkx as nx
except ImportError:
    NETWORK = False
    print('networkx not found! Text placement might not be optimal')
else:
    NETWORK = True


# Every DNA triplet, enumerated with the alphabet order A, T, G, C.
codon_list = list(map("".join, itertools.product('ATGC', repeat=3)))
protein_letters = "ACDEFGHIKLMNPQRSTVWY"
TOTAL_CODON = 64


def parse_json_file(jfile):
    """Read the file at path ``jfile`` and return its decoded JSON content."""
    with open(jfile) as handle:
        return json.load(handle)


def kernel_estimate(dtlen, labels):
    """Fit a Gaussian kernel density estimator on ``dtlen`` and return it.

    The fitted estimator used to be dropped, which left the function without
    any observable result; callers that ignored the former ``None`` return
    value are unaffected.

    Parameters
    ----------
    dtlen : array-like
        Samples handed unchanged to ``KernelDensity.fit``.
    labels : object
        Currently unused; kept so existing call sites keep working.

    Returns
    -------
    The object returned by ``KernelDensity(kernel='gaussian').fit(dtlen)``.
    """
    kde = KernelDensity(kernel='gaussian').fit(dtlen)
    return kde


def print_data_to_file(data, labels, outfile):
    """Write the codon count table as tab-separated text.

    The table goes to ``outfile`` with its extension replaced by ``.txt``.
    The first line is a header (``Species`` followed by every codon of
    ``codon_list``); each following line holds one label and the matching
    row of ``data`` with the values truncated to integers.

    Parameters
    ----------
    data : 2-D array
        One row per entry of ``labels``.
    labels : sequence of str
        Row names, written in the first column.
    outfile : str
        Output path; only its final extension is replaced.
    """
    # splitext only removes the last extension, so a dot elsewhere in the
    # path ("./out.png", "run.1/out") no longer truncates the file name
    base = os.path.splitext(outfile)[0]
    with open(base + ".txt", 'w') as OUT:
        OUT.write("\t".join(['Species'] + codon_list) + "\n")
        for label, row in zip(labels, data):
            OUT.write("\t".join([label] + [str(int(x)) for x in row]) + "\n")


def get_representation(codons, scale=False, outputfile=None, speclist=None):
    """Build the species x codon matrix used for the PCA.

    Parameters
    ----------
    codons : dict
        Mapping ``species -> {codon: count}``. A codon missing for a species
        counts as 0.
    scale : bool
        When true the normalised matrix is passed through ``StandardScaler``.
    outputfile : str or None
        When set, the raw counts (all codons of ``codon_list``) are written
        with ``print_data_to_file`` before any column is removed.
    speclist : iterable of str, optional
        Keep only these species. ``None`` or an empty value keeps them all.

    Returns
    -------
    tuple
        ``(data, labels, data_len)``: the matrix after each retained codon
        column has been divided by its sum (and optionally scaled), the row
        labels, and those column sums.

    Side effects
    ------------
    Prints the discarded codons and sets the module-level ``TOTAL_CODON`` to
    the number of retained codons.
    """
    global TOTAL_CODON
    labels = list(codons.keys())
    if speclist:
        wanted = set(speclist)
        # filtering (rather than a set intersection) keeps the key order of
        # `codons`, so rows come out in a reproducible order
        labels = [spec for spec in labels if spec in wanted]
    # the np.float alias was removed from NumPy; builtin float is equivalent
    data = np.zeros((len(labels), len(codon_list)), dtype=float)
    for i, spec in enumerate(labels):
        counts = codons[spec]
        data[i, :] = [counts.get(cod, 0) for cod in codon_list]
    if outputfile:
        print_data_to_file(data, labels, outputfile)

    # discard columns where we have zeros: the column-wise division below
    # would otherwise be 0/0
    cols = np.where(np.sum(data, axis=0) == 0)[0]
    discarded = [codon_list[x] for x in cols]
    # assigned (not decremented) so that repeated calls do not keep shrinking
    # the counter
    TOTAL_CODON = len(codon_list) - len(discarded)
    # single parenthesised argument: valid as a statement on Python 2 and as
    # a function call on Python 3, with the same output as before
    print("%s %s" % ("The following codon are discarded : ",
                     ",".join(discarded)))
    data = np.delete(data, cols, axis=1)
    data_len = np.sum(data, axis=0)
    data = data / data_len
    if scale:
        data = StandardScaler().fit_transform(data)
    return data, labels, data_len


def doPCA(data, n=2):
    """Return ``data`` projected on its first principal components.

    Parameters
    ----------
    data : 2-D array
        Matrix handed to ``PCA.fit_transform``.
    n : int
        Requested number of components. It is capped by both dimensions of
        ``data``.

    Returns
    -------
    The array returned by ``PCA.fit_transform``.
    """
    n_components = min(n, data.shape[1], data.shape[0])
    # the capped value used to be computed and then ignored in favour of a
    # hard-coded 2, which made the `n` argument a no-op
    return PCA(n_components=n_components).fit_transform(data)


def cluster(data, algo, params={}):
    """Assign a cluster label to every row of ``data``.

    ``algo`` is compared case-insensitively with ``'DBSCAN'``; any other
    value selects MeanShift. For DBSCAN ``params`` is forwarded to the
    estimator (next to ``min_samples=3``). For MeanShift ``params`` is
    forwarded to ``estimate_bandwidth`` and the number of clusters found is
    printed. Returns the label array produced by ``fit_predict``.
    """
    if algo.upper() == 'DBSCAN':
        return DBSCAN(min_samples=3, **params).fit_predict(data)

    bw = estimate_bandwidth(data, **params)
    assignments = MeanShift(bandwidth=bw, bin_seeding=True).fit_predict(data)
    # -1 marks noise and is not counted as a cluster
    found = set(assignments)
    found.discard(-1)
    print('Estimated number of clusters: %d' % len(found))
    return assignments


def repel_labels(ax, labels, data, k):
    """Annotate each point with its label, spreading the labels apart.

    Every point becomes a pinned graph node linked to a free node carrying
    its label. ``nx.spring_layout`` then moves only the label nodes, and an
    arrow is drawn from each label to its point. The axis limits are finally
    widened by 15% of the occupied span on each side.

    Parameters
    ----------
    ax : matplotlib axes
        Axes receiving the annotations and the new limits.
    labels : sequence of str
        One label per row of ``data``.
    data : 2-D array
        Point coordinates; only the first two columns are used.
    k : float
        Forwarded as ``k`` to ``nx.spring_layout``.
    """
    # See Stack overflow answer from
    # https://goo.gl/XdtaMx
    G = nx.DiGraph()
    data_nodes = []
    init_pos = {}
    for xyi, label in zip(data, labels):
        data_str = 'data_{0}'.format(label)
        G.add_node(data_str)
        G.add_node(label)
        G.add_edge(label, data_str)
        data_nodes.append(data_str)
        init_pos[label] = init_pos[data_str] = (xyi[0], xyi[1])

    if not data_nodes:
        # nothing to annotate; the stacking below would fail on empty input
        return

    pos = nx.spring_layout(G, pos=init_pos, fixed=data_nodes, k=k)

    # undo spring_layout's rescaling
    pos_after = np.vstack([pos[d] for d in data_nodes])
    pos_before = np.vstack([init_pos[d] for d in data_nodes])
    # When the pinned points did not move there is nothing to undo, and the
    # fit is skipped: it is ill-conditioned for a single point or for points
    # sharing a coordinate.
    if not np.allclose(pos_after, pos_before):
        # one scale per axis: the x scale used to be overwritten by the y one
        scl_x, shift_x = np.polyfit(pos_after[:, 0], pos_before[:, 0], 1)
        scl_y, shift_y = np.polyfit(pos_after[:, 1], pos_before[:, 1], 1)
        scale = np.array([scl_x, scl_y])
        shift = np.array([shift_x, shift_y])
        pos = dict((key, np.asarray(val) * scale + shift)
                   for key, val in pos.items())

    for label, data_str in G.edges():
        ax.annotate(label,
                    xy=pos[data_str], xycoords='data',
                    xytext=pos[label], textcoords='data',
                    arrowprops=dict(arrowstyle="->",
                                    connectionstyle="arc3"), )
    # expand limits; list() because a dict view is not a sequence for vstack
    all_pos = np.vstack(list(pos.values()))
    # pad each axis with its own span (x_span used to pad both minima and
    # y_span both maxima)
    span = np.ptp(all_pos, axis=0)
    mins = all_pos.min(axis=0) - 0.15 * span
    maxs = all_pos.max(axis=0) + 0.15 * span
    ax.set_xlim([mins[0], maxs[0]])
    ax.set_ylim([mins[1], maxs[1]])


def plot_result(data, labels, output="output.png", algo='MeanShift', params=None, kk=0.005):
    """Cluster ``data`` and save a labelled scatter plot of its first two columns.

    Parameters
    ----------
    data : 2-D array
        Coordinates of the points. All columns are passed to ``cluster``;
        columns 0 and 1 are plotted.
    labels : sequence of str
        One text label per row of ``data``.
    output : str
        Path handed to ``savefig``.
    algo : str
        Clustering algorithm name, forwarded to ``cluster`` and shown in the
        title.
    params : dict, optional
        Extra keyword arguments forwarded to ``cluster``.
    kk : float
        ``k`` value handed to ``repel_labels`` when networkx is available.
    """
    clusterlab = cluster(data, algo, params=params or {})
    # sorted so that the colour given to each cluster is reproducible
    unique_labels = sorted(set(clusterlab))
    if SEABORN:
        colors = sns.color_palette("husl", len(unique_labels))
    else:
        cm = plt.get_cmap('Spectral')
        colors = [cm(1. * i / len(unique_labels))
                  for i in range(len(unique_labels))]

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # tick_params expects booleans; the legacy 'off' strings are not False
    ax.tick_params(axis='both', which='both', top=False, right=False)
    fig.subplots_adjust(bottom=0.1)

    for k, col in zip(unique_labels, colors):
        if k == -1:
            # Black used for noise.
            col = 'k'
        class_member_mask = (clusterlab == k)
        xy = data[class_member_mask, :2]
        # `color=` rather than `c=`: an RGB(A) tuple given as `c` is read as
        # per-point values when the cluster has exactly 3 or 4 points
        ax.scatter(xy[:, 0], xy[:, 1], marker='o', color=col, s=300)

    if NETWORK:
        repel_labels(ax, labels, data, kk)
    else:
        for l, x, y in zip(labels, data[:, 0], data[:, 1]):
            ax.annotate(l, xy=(x, y), xytext=(-12, 15), textcoords='offset points',
                        clip_on=False, multialignment='center', rasterized=True,
                        size='small', ha='right', va='center', arrowprops=dict(arrowstyle='->', connectionstyle="arc3"))

    ax.set_xlabel('Component 1')
    ax.set_ylabel('Component 2')
    ax.set_title('PCA and %s clustering of genomes according to codon usage' %
                 algo, y=1.06)
    fig.savefig(output, bbox_inches='tight')


def compute_KL(data, avg):
    """Return the element-wise terms ``data * log(data / avg)``.

    ``FEPS`` is added to both the numerator and the denominator of the ratio
    before the logarithm is taken.
    """
    # in theory, all nan should be removed with this
    ratio = (data + FEPS) / (avg + FEPS)
    log_ratio = np.log(ratio)
    return data * log_ratio


def _aa_frequencies(data, codontable, labels, aa_list, rea_table):
    """Return (per-species frequency matrix, pooled frequency vector) over ``aa_list``."""
    rows = []
    pooled = defaultdict(int)
    pooled_total = 0.0
    for spec in labels:
        spec_counts = defaultdict(int)
        spec_table = rea_table.get(spec, {})
        for cod, count in data[spec].items():
            if cod == '---':
                continue
            # a species-specific reassignment wins over the default table;
            # codons unknown to both are treated like stops and skipped
            cur_aa = spec_table.get(
                cod, codontable.forward_table.get(cod, '*'))
            if cur_aa == '*':
                continue
            cur_aa = cur_aa.upper()
            spec_counts[cur_aa] += count
            pooled[cur_aa] += count
        aa_total = sum(spec_counts.get(aa, 0) for aa in aa_list)
        pooled_total += aa_total
        # a species without any counted codon gets an all-zero row instead
        # of raising ZeroDivisionError
        rows.append([spec_counts.get(aa, 0) * 1.0 / aa_total if aa_total else 0.0
                     for aa in aa_list])
    avg = [pooled.get(aa, 0) * 1.0 / pooled_total if pooled_total else 0.0
           for aa in aa_list]
    return np.asarray(rows), np.asarray(avg)


def plot_aa_usage(data, codontable, labels, outfile, aadiscard="", rea_table=None, kl=False):
    """Plot the amino-acid usage of each species and save it to ``outfile``.

    Parameters
    ----------
    data : dict
        Mapping ``species -> {codon: count}``; the ``'---'`` key is ignored.
    codontable : object
        Must expose a ``forward_table`` mapping codon -> amino acid.
    labels : sequence of str
        Species to plot, in plotting order; each must be a key of ``data``.
    outfile : str
        Path handed to ``savefig``.
    aadiscard : container of str, optional
        Amino-acid letters left out of the plot. ``None`` is accepted and
        means "discard nothing".
    rea_table : dict, optional
        Mapping ``species -> {codon: amino acid}`` overriding ``codontable``
        for that species.
    kl : bool
        When true, plot ``compute_KL`` of each species' frequencies against
        the pooled frequencies instead of the frequencies themselves.
    """
    # callers may pass None when no amino acid is discarded
    aadiscard = aadiscard or ""
    rea_table = rea_table or {}
    aa_list = [aa for aa in protein_letters if aa not in aadiscard]
    y, avg_y = _aa_frequencies(data, codontable, labels, aa_list, rea_table)
    ybl = 'AA Freq'
    if kl:
        y = compute_KL(y, avg_y)
        ybl = 'KL divergence'

    plt.clf()
    if SEABORN:
        sns.set_palette("husl", len(aa_list))
        df = pd.DataFrame(y, index=labels, columns=list(aa_list))
        df = df.stack().reset_index().rename(
            columns={'level_0': 'species', 'level_1': 'AA', 0: ybl})
        # cycle through three markers, exactly one per amino acid
        mk = (['*', 'o', 'd'] * (-(-len(aa_list) // 3)))[:len(aa_list)]
        # NOTE(review): factorplot/size= are the legacy seaborn names (newer
        # releases use catplot/height=) -- confirm against the seaborn
        # version this script is run with.
        sw_ax = sns.factorplot(x='species', y=ybl, aspect=4, legend=False,
                               data=df, hue='AA', size=3, ci=None, markers=mk, scale=0.6)

        if not kl:
            sw_ax.set(ylim=(0, None))
        sw_ax.despine(offset=3, trim=True)
        for ax in sw_ax.axes.flat:
            for lax in ax.get_xticklabels():
                lax.set_rotation(80)
        plt.legend(loc='upper left', bbox_to_anchor=(
            1.005, 1), ncol=2, fancybox=True, shadow=True)
        sw_ax.savefig(outfile, bbox_inches="tight")

    else:
        cm = plt.get_cmap('gist_rainbow')
        colors = [cm(1. * i / len(aa_list)) for i in range(len(aa_list))]
        LStyle = ['solid', 'dashed', 'dotted']
        fig, ax = plt.subplots(1, 1)
        for i, aa in enumerate(aa_list):
            ax.plot(y[:, i], color=colors[i],
                    linestyle=LStyle[i % len(LStyle)], label=aa)

        # tick_params expects booleans; the legacy 'off' strings are not False
        ax.tick_params(axis='both', which='both', top=False, right=False)
        plt.legend(loc='upper left',
                   bbox_to_anchor=(1.01, 1), ncol=2, fancybox=True, shadow=True)
        # rotation is set together with the labels: rotating the tick labels
        # before plt.xticks replaced them did not reliably stick
        plt.xticks(range(0, len(labels)), labels, rotation=80)
        plt.subplots_adjust(bottom=0.1)
        plt.margins(0.1)
        plt.xlabel('species')
        plt.ylabel(ybl)
        plt.title('A.A. frequency in each species', y=1.06)
        plt.savefig(outfile, bbox_inches="tight")


# Command-line entry point: parse the options, build the codon matrix, run
# the PCA + clustering plot and, on request, the amino-acid usage plot.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Plot codon or amino acid usage in genome')
    parser.add_argument('--reafile', '-i', dest="reafile", required=True,
                        help="Json rea file")
    parser.add_argument('--outfile', '-o', dest="outfile",
                        default="output", help="Outfile file")
    parser.add_argument('--scale', action='store_true',
                        dest="scale", help="Scale data")
    parser.add_argument('--repel_k', type=float, default=0.005, help="Optimal distance between nodes for networkx spring layout.\
                        Experiments with this value for the best label placements")
    parser.add_argument('--plot_aa_usage', action='store_true',
                        dest="aausage", help="Plot aa usage (frequency)")
    parser.add_argument('--csv', action='store_true', dest="csv",
                        help="Export codon count in csv format")
    parser.add_argument('--KL', action='store_true', dest="KL",
                        help="Use KL divergence instead of frequency")
    parser.add_argument('--speclist', '-s', dest="speclist",
                        help="Only use the provided list of species.")
    parser.add_argument('--reatable', '-r', dest="reatable", help="Use a specific table for the genome listed in this files.\
                        This file should be in a format similar to Coretracker's output (predictions.txt).")
    parser.add_argument('--aa_discard', '--aar', dest="aadiscard", nargs='+',
                        help="Discard the following aa from the aausage plot")
    parser.add_argument('--gcode', type=int, dest="defcode", default=4,
                        help="Default genetic code to use")
    parser.add_argument('--pca_component', type=int, dest="pca",
                        default=2, help="Number of component for pca")
    parser.add_argument('--clustering_algo', '--algo', choices=('MeanShift', 'DBSCAN'),
                        default='MeanShift', dest="algo", help="Clustering algorithm to use, default: MeanShift")
    parser.add_argument('--bandwidth_quantile', dest="ebquant", type=float,
                        default=0.3, help="Quantile value for bandwidth estimation")

    args = parser.parse_args()
    data = parse_json_file(args.reafile)
    speclist = []
    new_table = defaultdict(dict)
    if args.speclist:
        try:
            # one species per line, '#' starts a comment line; the context
            # manager closes the handle, which used to be leaked
            with open(args.speclist) as SPECFILE:
                for line in SPECFILE:
                    line = line.strip()
                    if line and not line.startswith('#'):
                        speclist.append(line)
        except IOError:
            # not a readable file: read the value as an inline list
            speclist = re.split(';|,| |-', args.speclist)
            # consecutive separators ("a, b") produce empty items; drop them
            speclist = [x.strip() for x in speclist if x.strip()]
    if args.reatable:
        curr_cod = ""
        dest_aa = ""
        # raw string: same pattern as before without invalid-escape warnings
        pattern = re.compile(r"^[A-Z]{3,}\s?\([A-Z\*]\,\s[A-Z\*]\)")
        with open(args.reatable) as RTable:
            for line in RTable:
                line = line.strip()
                if line and not line.startswith('#'):
                    if pattern.match(line):
                        # header line "COD (X, Y)": remember the codon and
                        # the amino acid after the comma
                        curr_cod = line.split('(')[0].strip().replace('U', 'T')
                        dest_aa = line.split(',')[-1].strip(' )')
                    else:
                        # any other line: its first word is a species that
                        # uses the reassignment from the last header
                        spec = line.strip().split()[0].strip()
                        if curr_cod and dest_aa:
                            new_table[spec][curr_cod] = dest_aa
    dt, labels, _ = get_representation(
        data['codons'], args.scale, (args.outfile if args.csv else None), speclist)
    pca_data = doPCA(dt, args.pca)
    # splitext copes with dots in directory names and with several dots,
    # where unpacking split('.') raised ValueError or split in the wrong place
    outfile, ext = os.path.splitext(args.outfile)
    fmt = ext[1:] or 'svg'

    if args.algo == 'DBSCAN':
        params = {}
    else:
        params = {'quantile': args.ebquant}
    # --repel_k used to be parsed but never forwarded
    plot_result(pca_data, labels, outfile + "." +
                fmt, algo=args.algo, params=params, kk=args.repel_k)

    # nargs='+' gives a list of strings (None when the flag is absent);
    # flatten it to a set of single letters
    args.aadiscard = set(itertools.chain.from_iterable(args.aadiscard or []))
    if args.aausage:
        codontable = CodonTable.unambiguous_dna_by_id[args.defcode]
        outfile = outfile + '_aa_usage'
        plot_aa_usage(data['codons'], codontable, labels,
                      outfile + "." + fmt, aadiscard=args.aadiscard, rea_table=new_table, kl=args.KL)

    # TODO: gene length distribution outlier and gene number ==> better in the
    # web code
