#! /usr/bin/env python

import sys
import itertools
import csv
from optparse import OptionParser
import numpy as np
from minepy import MINE, __version__


def pearson(x, y):
    """Return the Pearson correlation coefficient between x and y.

    Parameters
    ----------
    x, y : array_like
        One-dimensional sequences of numbers with the same length.

    Returns
    -------
    float
        The sample correlation coefficient r. If either input has zero
        variance the denominator is zero and the result is nan.

    Raises
    ------
    ValueError
        If x and y do not have the same length.
    """

    # The builtin float is the dtype the removed ``np.float`` alias
    # pointed to, so this works on both old and current NumPy releases.
    xa = np.asarray(x, dtype=float)
    ya = np.asarray(y, dtype=float)
    if xa.shape[0] != ya.shape[0]:
        raise ValueError("x, y: shape mismatch")

    # Center both variables, then r = cov(x, y) / (||xc|| * ||yc||).
    xc = xa - np.mean(xa)
    yc = ya - np.mean(ya)
    a = np.sum(xc * yc)
    b = np.sqrt(np.sum(xc ** 2))
    c = np.sqrt(np.sum(yc ** 2))

    return a / (b * c)


def compute_stats(data, varname, idx1, idx2, mine, writer):
    """Compute the statistics between data[idx1] and data[idx2].

    Writes one row on writer containing the two variable names followed by
    MIC, MIC-r^2, MAS, MEV, MCN, MCN_GENERAL and Pearson.

    data is indexed by variable; each row holds the samples as strings,
    where an empty string marks a missing value. Samples missing in either
    variable are dropped before the statistics are computed.
    varname holds the variable names, mine is the MINE object used to
    compute the scores and writer is a csv writer.
    """

    print("%s vs %s..." % (varname[idx1], varname[idx2]))

    # Keep only the samples for which both variables have a value.
    valid = (data[idx1] != '') & (data[idx2] != '')
    x = data[idx1][valid].astype(float)
    y = data[idx2][valid].astype(float)

    mine.compute_score(x, y)
    r = pearson(x, y)
    mic = mine.mic()
    writer.writerow([varname[idx1], varname[idx2], mic,
                     mic - r ** 2, mine.mas(), mine.mev(),
                     mine.mcn(0), mine.mcn_general(), r])


def _build_parser():
    """Build the command-line option parser of the mine script."""

    description = "MINE Python v. %s [Homepage: minepy.sf.net]." \
        " The mine script compares by default all pairs of variables against each other." \
        " It writes an output file where each column contains MIC (strength)," \
        " MIC-r^2 (nonlinearity), MAS (non-monotonicity), MEV (functionality)," \
        " MCN (complexity, eps=0), MCN_GENERAL (complexity, eps=1-MIC) and Pearson (r)." \
        " The input must be a comma-separated values file where the first column" \
        " must contain the variable names. Each variable must have the same" \
        " number of samples." % __version__

    usage = "%prog infile [-a <alpha>] [-c <c>] [-o <file>] [-m <var index>]" \
        " [-p <var1 index> <var2 index>]"

    parser = OptionParser(description=description, usage=usage)
    parser.add_option("-a", "--alpha", dest="alpha",
        help="the exponent in B(n) = n^alpha (default: %default.)"
        " alpha must be in (0, 1.0]",
        metavar="<alpha>", type="float", default=0.6)
    parser.add_option("-c", "--clumps", dest="c",
        help= "determines how many more clumps there will be than"
        " columns in every partition. Default value is %default,"
        " meaning that when trying to draw Gx grid lines on"
        " the x-axis, the algorithm will start with at most"
        " %default*Gx clumps (default: %default). c must be > 0",
        metavar="<c>", type="float", default=15)
    parser.add_option("-o", "--output", dest="outfile",
        help="output filename (default: %default)",
        metavar="<file>", type="str", default="mine_out.csv")
    parser.add_option("-m", "--master", dest="master",
        help="variable <var index> vs. all"
        " <var index> must be in [1, number of variables in file]",
        metavar="<var index>", type="int", nargs=1)
    parser.add_option("-p", "--pair", dest="pair",
        help="variable <var1 index> vs. variable <var2 index>"
        " <var1 index> and <var2 index> must be in [1, number of variables in file]",
        metavar="<var1 index> <var2 index>", type="int", nargs=2)

    return parser


def main():
    """Command-line entry point.

    Reads the comma-separated input file named on the command line (first
    column: variable names, remaining columns: samples), computes the
    statistics for the requested pairs of variables and writes them to the
    output csv file.

    Exits with status 2 on a command-line usage error and with status 1
    when an option value is rejected (invalid alpha/c or a variable index
    out of range).
    """

    parser = _build_parser()
    options, args = parser.parse_args()

    if len(args) != 1:
        parser.print_help()
        sys.exit(2)

    # Compare against None: a plain truth test would let "-m 0" slip
    # through. parser.error() prints the message and exits with status 2.
    if options.master is not None and options.pair is not None:
        parser.error("options -m and -p are mutually exclusive")

    try:
        mine = MINE(alpha=options.alpha, c=options.c)
    except ValueError as e:
        print(e)
        sys.exit(1)

    # ndmin=2 keeps the array two-dimensional even when the file contains
    # a single variable (one row), so the column slicing below still works.
    data = np.loadtxt(args[0], delimiter=',', dtype=str, ndmin=2)
    varname = data[:, 0].tolist()
    data = data[:, 1:]
    p, n = data.shape

    print("Read file %s: %d variables, %d samples" % (args[0], p, n))

    # check options.master if not None
    if options.master is not None:
        if not 1 <= options.master <= p:
            print("wrong variable index in option -m (%d)."
                  " Index must be in [1, %d]." % (options.master, p))
            sys.exit(1)

    # check options.pair if not None
    if options.pair is not None:
        if not all(1 <= idx <= p for idx in options.pair):
            print("wrong variable index in option -p (%d %d)."
                  " Indexes must be in [1, %d]."
                  % (options.pair[0], options.pair[1], p))
            sys.exit(1)

    # Select the (0-based) pairs of variables to compare.
    if options.master is not None:
        idx1 = options.master - 1
        pairs = [(idx1, idx2) for idx2 in range(p) if idx2 != idx1]
    elif options.pair is not None:
        pairs = [(options.pair[0] - 1, options.pair[1] - 1)]
    else:  # default: all pairs of variables against each other
        pairs = itertools.combinations(range(p), 2)

    # The with statement closes the output file even if a computation fails.
    with open(options.outfile, "w") as fout:
        writer = csv.writer(fout, delimiter=',', lineterminator='\n')
        writer.writerow(["X", "Y", "MIC (strength)", "MIC-r^2 (nonlinearity)",
                         "MAS (non-monotonicity)", "MEV (functionality)",
                         "MCN (complexity, eps=0)", "MCN_GENERAL (complexity, eps=1-MIC)",
                         "Pearson (r)"])

        for idx1, idx2 in pairs:
            compute_stats(data, varname, idx1, idx2, mine, writer)


# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
