#!/usr/bin/env python

# CoreTracker Copyright (C) 2016  Emmanuel Noutahi
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from sklearn.decomposition import PCA
from sklearn.cluster import DBSCAN, MeanShift, estimate_bandwidth
from sklearn.neighbors import KernelDensity
import numpy as np
from sklearn import metrics
import matplotlib.pyplot as plt
import matplotlib.cm as mplcm
import matplotlib.colors as mc
import matplotlib.patches as mpatches
import argparse
import json
from sklearn.preprocessing import StandardScaler
import itertools
from collections import defaultdict
from Bio.Data import CodonTable
import re
import pandas as pd

# seaborn is optional: when it can be imported it is configured once here
# and SEABORN is set so plotting code can pick the seaborn code path.
try:
    import seaborn as sns
    sns.set(context="paper", palette="deep")
    sns.set_style("white")
except ImportError:
    # quiet fallback to matplotlib
    SEABORN = False
else:
    SEABORN = True

# Every nucleotide triplet over the alphabet 'ATGC', in itertools.product order.
codon_list = list(map("".join, itertools.product('ATGC', repeat=3)))
# One-letter codes of the amino acids shown in the usage plot.
protein_letters = "ACDEFGHIKLMNPQRSTVWY"
# Number of codon columns in use; lowered by get_representation when
# all-zero columns are dropped.
TOTAL_CODON = len(codon_list)


def parse_json_file(jfile):
    """Load the JSON document stored at path ``jfile`` and return it."""
    with open(jfile) as handle:
        return json.load(handle)


def kernel_estimate(dtlen, labels):
    """Fit a Gaussian kernel density estimator on ``dtlen``.

    dtlen  -- samples handed unchanged to ``KernelDensity.fit``.
    labels -- currently unused; kept so existing callers keep working.

    Returns the fitted ``KernelDensity`` instance. The original code fitted
    the model and then dropped it, so callers could not use the result.
    """
    return KernelDensity(kernel='gaussian').fit(dtlen)


def print_data_to_file(data, labels, outfile):
    """Write the codon count matrix as a tab-separated text file.

    data    -- 2-D array with one row per entry of ``labels`` and one column
               per entry of ``codon_list``; values are written as integers.
    labels  -- row names, written in the first ('Species') column.
    outfile -- output path; its extension (if any) is replaced by ``.txt``.
    """
    # Local import: the module header does not import os.
    import os

    # splitext only strips the final extension, so dots in directory names
    # (e.g. "./results/out.png") or in the base name are left untouched.
    outfile = os.path.splitext(outfile)[0] + ".txt"
    with open(outfile, 'w') as OUT:
        OUT.write("\t".join(['Species'] + codon_list) + "\n")
        for label, row in zip(labels, data):
            counts = [str(int(x)) for x in row]
            OUT.write("\t".join([label] + counts) + "\n")


def get_representation(codons, scale=False, outputfile=None, speclist=None):
    """Build the species x codon matrix used for PCA and clustering.

    codons     -- mapping ``species -> {codon: count}``; codons missing for a
                  species count as 0.
    scale      -- when true, the normalised matrix is passed through
                  ``StandardScaler().fit_transform``.
    outputfile -- when set, the raw counts are dumped with
                  ``print_data_to_file`` before any column is removed.
    speclist   -- optional collection of species names; when non-empty only
                  species present in both ``codons`` and ``speclist`` are kept.

    Returns ``(data, labels, data_len)`` where ``labels`` are the row names,
    ``data_len`` holds the per-codon column totals (summed over species) and
    ``data`` is the count matrix with all-zero columns removed and every
    column divided by its total.

    Side effect: the module-level ``TOTAL_CODON`` is set to the number of
    codon columns that were kept.
    """
    global TOTAL_CODON

    labels = list(codons.keys())
    if speclist:
        # Filter in key order rather than through a set intersection so the
        # row order is reproducible from run to run.
        wanted = set(speclist)
        labels = [spec for spec in labels if spec in wanted]

    # Builtin float: the np.float alias no longer exists in recent NumPy.
    data = np.zeros((len(labels), len(codon_list)), dtype=float)
    for i, spec in enumerate(labels):
        spec_counts = codons[spec]
        data[i, :] = [spec_counts.get(cod, 0) for cod in codon_list]

    if outputfile:
        print_data_to_file(data, labels, outputfile)

    # discard columns where we have zeros
    cols = np.where(np.sum(data, axis=0) == 0)[0]
    discarded = [codon_list[x] for x in cols]
    # Recomputed from scratch (instead of decremented) so that calling this
    # function more than once does not keep shrinking the global.
    TOTAL_CODON = len(codon_list) - len(discarded)
    # Single pre-formatted argument: same output whether print is a
    # statement or a function.
    print("The following codon are discarded :  %s" % ",".join(discarded))

    data = np.delete(data, cols, axis=1)
    # Column-wise normalisation; no division by zero since empty columns
    # were just removed.
    data_len = np.sum(data, axis=0)
    data = data / data_len
    if scale:
        data = StandardScaler().fit_transform(data)
    return data, labels, data_len


def doPCA(data, n=2):
    """Project ``data`` onto its first principal components.

    data -- 2-D array (rows are samples, columns are features).
    n    -- requested number of components; clamped to the smaller dimension
            of ``data``.

    Returns the transformed array produced by ``PCA.fit_transform``.
    """
    n = min(n, data.shape[1], data.shape[0])
    # Use the clamped value; it used to be computed and then ignored in
    # favour of a hard-coded 2.
    return PCA(n_components=n).fit_transform(data)


def cluster(data, algo, params={}):
    """Cluster the rows of ``data`` and return one cluster label per row.

    ``algo`` equal to 'DBSCAN' (case-insensitive) selects DBSCAN with
    ``min_samples=3`` and ``params`` forwarded to its constructor. Any other
    value selects MeanShift, with ``params`` forwarded to
    ``estimate_bandwidth``; in that case the number of clusters found is
    printed.
    """
    use_dbscan = algo.upper() == 'DBSCAN'
    if use_dbscan:
        model = DBSCAN(min_samples=3, **params)
    else:
        model = MeanShift(bandwidth=estimate_bandwidth(data, **params),
                          bin_seeding=True)

    assignments = model.fit_predict(data)

    if not use_dbscan:
        # The label -1 marks noise and is not counted as a cluster.
        noise = 1 if -1 in assignments else 0
        n_found = len(set(assignments)) - noise
        print('Estimated number of clusters: %d' % n_found)
    return assignments


def plot_result(data, labels, output="output.png", algo='MeanShift', params=None):
    """Cluster ``data`` and save an annotated scatter plot of its first two columns.

    data   -- 2-D array with at least two columns (e.g. PCA output); all
              columns are used for clustering, the first two for plotting.
    labels -- one text annotation per row of ``data``.
    output -- path of the image written with ``plt.savefig``.
    algo   -- clustering algorithm name, forwarded to ``cluster``.
    params -- optional dict of extra keyword arguments for ``cluster``.
    """
    clusterlab = cluster(data, algo, params=params or {})
    # Sorted so the cluster -> colour assignment is reproducible.
    unique_labels = sorted(set(clusterlab))
    colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # Booleans, not 'off' strings: current matplotlib treats any non-empty
    # string as true, which would leave the ticks and labels visible.
    ax.tick_params(axis='both', which='both', labelbottom=False, labelleft=False,
                   bottom=False, left=False, top=False, right=False)
    plt.subplots_adjust(bottom=0.1)
    for k, col in zip(unique_labels, colors):
        if k == -1:
            # Black used for noise (given as RGBA so rgb2hex accepts it on
            # every matplotlib version).
            col = (0.0, 0.0, 0.0, 1.0)
        class_member_mask = (clusterlab == k)
        xy = data[class_member_mask, :2]
        ax.scatter(xy[:, 0], xy[:, 1], marker='o', c=mc.rgb2hex(col), s=500)

    for name, x, y in zip(labels, data[:, 0], data[:, 1]):
        ax.annotate(name, xy=(x, y), xytext=(-12, 15), textcoords='offset points',
                    clip_on=True, multialignment='center', rasterized=True,
                    size='small', ha='right', va='center',
                    arrowprops=dict(arrowstyle='->', connectionstyle="arc3"))

    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.title('PCA and %s Clustering of genome according to codon usage' %
              algo, y=1.06)
    plt.savefig(output, bbox_inches='tight')


def _count_amino_acids(codon_counts, codontable, spec_table):
    """Sum the codon counts of one species per (upper-cased) amino acid."""
    totals = defaultdict(int)
    for cod, count in codon_counts.items():
        if cod == '---':
            continue
        # Species-specific reassignment first; the standard table is only
        # consulted when there is no override, so a reassigned codon that is
        # absent from forward_table no longer raises KeyError.
        aa = spec_table.get(cod)
        if aa is None:
            aa = codontable.forward_table.get(cod)
        if aa is None:
            # No amino acid for this codon in either table (e.g. a stop
            # codon): nothing to credit.
            continue
        totals[aa.upper()] += count
    return totals


def plot_aa_usage(data, codontable, labels, outfile, aadiscard="", rea_table=None):
    """Plot the amino-acid usage of each species and save it to ``outfile``.

    data       -- mapping ``species -> {codon: count}``; the '---' key is ignored.
    codontable -- object exposing ``forward_table`` (codon -> amino acid).
    labels     -- species to plot (keys of ``data``).
    outfile    -- path of the image to write.
    aadiscard  -- amino-acid letters to leave out of the plot; ``None`` is
                  treated as "discard nothing".
    rea_table  -- optional mapping ``species -> {codon: amino acid}`` whose
                  entries override ``codontable`` for that species.

    The seaborn swarm plot is used when seaborn was imported, otherwise a
    matplotlib scatter plot with jittered x positions.
    """
    aadiscard = aadiscard or ""
    rea_table = rea_table or {}
    aa_list = [aa for aa in protein_letters if aa not in aadiscard]

    y = []
    for spec in labels:
        specdata = _count_amino_acids(
            data[spec], codontable, rea_table.get(spec, {}))
        # NOTE(review): counts are divided by the module-level TOTAL_CODON
        # (number of codon columns kept), not by the species' total codon
        # count -- confirm this is the intended "frequency".
        y.append([specdata.get(aa, 0) * 1.0 /
                  TOTAL_CODON for aa in aa_list])

    y = np.asarray(y)
    plt.clf()
    if SEABORN:
        df = pd.DataFrame(y, index=labels, columns=list(aa_list))
        df = df.stack().reset_index().rename(
            columns={'level_0': 'species', 'level_1': 'AA', 0: 'Freq'})
        fig = plt.figure(figsize=(len(aa_list) * 0.75, len(aa_list) * 0.3))
        sw_ax = sns.swarmplot(x='AA', y='Freq', data=df, hue='species', size=4)
        sw_ax.set_ylim(0,)
        sns.despine(offset=10, trim=True)
        sw_ax.legend(loc='upper left', bbox_to_anchor=(
            1.01, 1), ncol=2, fancybox=True, shadow=True)
        fig.savefig(outfile, bbox_inches="tight")

    else:
        # One colour per species, spread over the 'gist_rainbow' colormap.
        cm = plt.get_cmap('gist_rainbow')
        cNorm = mc.Normalize(vmin=0, vmax=len(labels) - 1)
        scalarMap = mplcm.ScalarMappable(norm=cNorm, cmap=cm)
        colors = [scalarMap.to_rgba(i) for i in range(len(labels))]

        fig, ax = plt.subplots(1, 1)
        # After transposing, each row holds one amino acid across species.
        for i, freqs in enumerate(y.T):
            # Horizontal jitter around the tick of the i-th amino acid.
            x = np.random.normal((i + 1) * 3, 0.15, len(freqs))
            ax.scatter(x, freqs, color=colors)

        # Booleans, not 'off' strings, which current matplotlib reads as true.
        ax.tick_params(axis='both', which='both', top=False, right=False)
        leg = [mpatches.Patch(color=col, label=lab)
               for col, lab in zip(colors, labels)]
        plt.legend(handles=leg, loc='upper left',
                   bbox_to_anchor=(-0.01, -0.1), ncol=3, fancybox=True, shadow=True)
        plt.xticks(np.arange(0, len(aa_list) * 3, 3) + 3, aa_list)
        plt.subplots_adjust(bottom=0.1)
        plt.margins(0.1)
        plt.xlabel('Amino acid')
        plt.ylabel('frequency')
        plt.title('A.A. frequency in each species', y=1.06)
        plt.savefig(outfile, bbox_inches="tight")


if __name__ == '__main__':
    # Command-line entry point: load the codon counts, run PCA + clustering,
    # save the cluster plot and optionally the amino-acid usage plot.
    # Local import: the module header does not import os.
    import os

    parser = argparse.ArgumentParser(
        description='Plot codon or amino acid usage in genome')
    # Required: without it the script cannot load any data.
    parser.add_argument('--reafile', '-i', dest="reafile", required=True,
                        help="Json rea file")
    parser.add_argument('--outfile', '-o', dest="outfile",
                        default="output", help="Outfile file")
    parser.add_argument('--scale', action='store_true',
                        dest="scale", help="Scale data")
    parser.add_argument('--plot_aa_usage', action='store_true',
                        dest="aausage", help="Plot aa usage (frequency)")
    parser.add_argument('--csv', action='store_true', dest="csv",
                        help="Export codon count in csv format")
    parser.add_argument('--speclist', '-s', dest="speclist",
                        help="Only use the provided list of species.")
    parser.add_argument('--reatable', '-r', dest="reatable", help="Use a specific table for the genome listed in this file\
                        This file should be in a format similar to Coretracker's output (predictions.txt).")
    parser.add_argument('--aa_discard', '--aar', dest="aadiscard", nargs='+',
                        help="Discard the following aa from the aausage plot")
    parser.add_argument('--gcode', type=int, dest="defcode", default=4,
                        help="Default genetic code to use")
    parser.add_argument('--pca_component', type=int, dest="pca",
                        default=2, help="Number of component for pca")
    parser.add_argument('--clustering_algo', '--algo', choices=('MeanShift', 'DBSCAN'),
                        default='MeanShift', dest="algo", help="Clustering algorithm to use, default: MeanShift")
    parser.add_argument('--bandwidth_quantile', dest="ebquant", type=float,
                        default=0.3, help="Quantile value for bandwidth estimation")

    args = parser.parse_args()
    data = parse_json_file(args.reafile)
    speclist = []
    new_table = defaultdict(dict)
    if args.speclist:
        # --speclist is first tried as a file (one species per line, '#'
        # starts a comment); if it cannot be opened it is parsed as an
        # inline list separated by ';', ',', ' ' or '-'.
        try:
            with open(args.speclist) as SPECS:
                for line in SPECS:
                    line = line.strip()
                    if line and not line.startswith('#'):
                        speclist.append(line)
        except IOError:
            speclist = [x.strip() for x in re.split(';|,| |-', args.speclist)]
            # Consecutive separators ("a, b") produce empty tokens.
            speclist = [x for x in speclist if x]

    if args.reatable:
        curr_cod = ""
        dest_aa = ""
        # Header lines look like "CODON (X, Y)"; the lines that follow name
        # the species the reassignment applies to.
        pattern = re.compile(r"^[A-Z]{3,}\s?\([A-Z]\,\s[A-Z]\)")
        with open(args.reatable) as RTable:
            for line in RTable:
                line = line.strip()
                if line and not line.startswith('#'):
                    if pattern.match(line):
                        curr_cod = line.split('(')[0].strip()
                        dest_aa = line.split(',')[-1].strip(' )')
                    else:
                        spec = line.strip().split("\t")[0].strip()
                        if curr_cod and dest_aa:
                            new_table[spec][curr_cod] = dest_aa

    dt, labels, _ = get_representation(
        data['codons'], args.scale, (args.outfile if args.csv else None), speclist)
    pca_data = doPCA(dt, args.pca)

    # Split off only the final extension so that paths with several dots
    # ("./out/plot.png", "run.1.svg") no longer fail to unpack; the format
    # defaults to svg when no extension is given.
    outfile, ext = os.path.splitext(args.outfile)
    fmt = ext[1:] or 'svg'

    if args.algo == 'DBSCAN':
        params = {}
    else:
        params = {'quantile': args.ebquant}
    plot_result(pca_data, labels, outfile + "." +
                fmt, algo=args.algo, params=params)

    if args.aausage:
        # Flatten e.g. ["AC", "D"] into {"A", "C", "D"}; an empty string (not
        # None) means nothing is discarded.
        if args.aadiscard:
            aadiscard = set(itertools.chain.from_iterable(args.aadiscard))
        else:
            aadiscard = ""
        codontable = CodonTable.unambiguous_dna_by_id[args.defcode]
        outfile = outfile + '_aa_usage'
        plot_aa_usage(data['codons'], codontable, labels,
                      outfile + "." + fmt, aadiscard=aadiscard, rea_table=new_table)

    # TODO: gene length distribution outlier and gene number ==> better in the
    # web code
