#!python

import json
import argparse

import numpy as np
import networkx as nx
from networkx.readwrite import json_graph

import trilearn.auxiliary_functions
import trilearn.graph.graph as libg
import trilearn.distributions.g_intra_class as gic
from trilearn.graph.decomposable import gen_AR_graph

# Limit NumPy's printed floating point output to one digit of precision.
np.set_printoptions(precision=1)
# Graph intra-class model on an AR graph of width 2 (AR(2)).


def main(sigma2, rho, n_samples, n_dim, graph_dir, data_dir, precmat_dir, **args):
    """Sample from a graph intra-class model on an AR graph and save the results.

    An AR graph of width 2 on ``n_dim`` nodes is generated, samples are drawn
    from the intra-class model on that graph, and the following files are
    written (their names are printed at the end):

    * a heatmap of the graph's adjacency matrix and a plot of the graph
      (in ``graph_dir``),
    * the transposed output of ``gic.sample`` as CSV (in ``data_dir``),
    * the inverse of the covariance matrix (in ``precmat_dir``),
    * the graph in node-link JSON format (in ``graph_dir``).

    Args:
        sigma2: Variance parameter passed to the intra-class model.
        rho: Correlation coefficient passed to the intra-class model.
        n_samples: Number of samples to draw.
        n_dim: Number of nodes in the graph.
        graph_dir: Directory for the graph plots and the JSON file.
        data_dir: Directory for the sampled data.
        precmat_dir: Directory for the precision matrix.
        **args: Extra keyword arguments; accepted and ignored.
    """
    # Build every output name once so the files written and the names
    # reported below cannot drift apart.
    base = "/intraclass_p_" + str(n_dim)
    params = "_sigma2_" + str(sigma2) + "_rho_" + str(rho)
    heatmap_stem = graph_dir + base + "_heatmap"
    graph_plot_file = graph_dir + base + "_graph.eps"
    data_file = data_dir + base + params + "_n_" + str(n_samples) + ".csv"
    precmat_file = precmat_dir + base + params + "_omega.txt"
    json_file = graph_dir + base + ".json"

    graph = gen_AR_graph(n_dim, width=2)
    X = gic.sample(graph, rho, sigma2, n_samples).T
    c = gic.cov_matrix(graph, rho, sigma2)

    hm_true = nx.to_numpy_array(graph, nodelist=list(range(n_dim)))
    # The extension is passed separately; presumably plot_matrix appends it
    # to the stem (the name reported below assumes "<stem>.eps").
    trilearn.auxiliary_functions.plot_matrix(hm_true, heatmap_stem,
                                             "eps", "True graph")
    libg.plot(graph, graph_plot_file)
    np.savetxt(data_file, X, delimiter=',')
    # np.linalg.inv works for both np.matrix and plain ndarray inputs,
    # whereas the ``.I`` attribute only exists on np.matrix.
    np.savetxt(precmat_file, np.linalg.inv(c), delimiter=',')

    with open(json_file, 'w') as outfile:
        js_graph = json_graph.node_link_data(graph)
        json.dump(js_graph, outfile)

    print("Generated files:")
    print(heatmap_stem + ".eps")
    print(graph_plot_file)
    print(data_file)
    print(precmat_file)
    print(json_file)

if __name__ == '__main__':
    # Command-line entry point: parse the model parameters and output
    # directories, then forward them to main() as keyword arguments.
    parser = argparse.ArgumentParser(description="Generates samples from a graph "
                                                 "intra-class model and saves the precision matrix.")
    parser.add_argument(
        '--sigma2',
        type=float, required=True,
        help='Variance'
    )
    parser.add_argument(
        '--rho',
        type=float, required=True,
        help="Correlation coefficient"
    )
    parser.add_argument(
        '-n', '--n_samples',
        type=int, required=True,
        help="Number of normal samples"
    )
    parser.add_argument(
        '-p', '--n_dim',
        type=int, required=True,
        help="Number of dimensions"
    )
    parser.add_argument(
        '--graph_dir',
        required=False, default=".",
        help="Directory where to save the graph file"
    )
    parser.add_argument(
        '--data_dir',
        required=False, default=".",
        help="Directory where to save the data file"
    )
    parser.add_argument(
        '--precmat_dir',
        required=False, default=".",
        # The original help text was cut off after "save the".
        help="Directory where to save the precision matrix file"
    )

    # vars() exposes the parsed namespace as a dict keyed by the option
    # names, which match main()'s parameter names.
    main(**vars(parser.parse_args()))

