#!python

import argparse
import numpy as np
import nibabel as nib
from sklearn.linear_model import BayesianRidge


def get_mask(path):
    """Load a mask image and set every positive voxel to 1.

    Args:
        path: Path of the mask image to load with nibabel. A falsy value
            (``None`` or an empty string) means "no mask".

    Returns:
        A new array holding the image data in which every value greater
        than zero has been replaced by 1 (all other values are left as they
        are), or ``None`` when no path was given.
    """
    if not path:
        return None
    # Use the public ``dataobj`` accessor (as the other loaders in this file
    # do) and take an explicit copy so the in-place assignment below never
    # writes into a buffer owned by the loaded image.
    img = np.array(nib.load(path).dataobj)
    img[img > 0] = 1
    return img


def norm(data):
    """Return *data* standardized to zero mean and unit standard deviation.

    The mean and standard deviation are taken over all elements.

    Args:
        data: Array or array-like of numbers; converted with ``np.asarray``.

    Returns:
        Array of the same shape holding ``(data - mean) / std``. When the
        standard deviation is exactly zero (constant input) an all-zero
        array is returned instead of dividing by zero.
    """
    data = np.asarray(data)
    centered = data - data.mean()
    spread = data.std()
    if spread == 0:
        # Constant input: 0 / 0 would fill the result with NaN, so report
        # "no deviation from the mean" explicitly.
        return np.zeros_like(centered)
    return centered / spread


def get_affine(path):
    """Return a copy of the affine of the image stored at *path*."""
    image = nib.load(path)
    affine = image.affine
    # Hand back a copy so callers cannot modify the loaded image's affine.
    return affine.copy()


def dualRegression(args):
    """Run a two-stage Bayesian ridge regression and save the result.

    Stage 1 fits the voxel-by-time input data against one column of the
    regressor image, giving one coefficient per time point. Stage 2 fits the
    transposed (time-by-voxel) data against those coefficients, giving one
    coefficient per voxel. The voxel coefficients are standardized with
    ``norm`` and written as a 3D NIfTI image that reuses the input affine.

    Args:
        args: Parsed command-line namespace providing ``input`` (4D image
            path), ``regressor`` (image whose last axis indexes the
            regressors), ``mask`` (optional mask image path), ``index``
            (which regressor to use) and ``output`` (destination path).

    Raises:
        ValueError: If the input image is not four-dimensional.
    """
    input_data = np.asarray(nib.load(args.input).dataobj)
    ics = np.asarray(nib.load(args.regressor).dataobj)
    mask = get_mask(args.mask)
    comp_n = args.index

    if input_data.ndim != 4:
        raise ValueError(
            "input image must be 4D (x, y, z, t), got shape {}".format(
                input_data.shape))
    x, y, z, t = input_data.shape
    n_voxels = x * y * z
    ics = ics.reshape(n_voxels, ics.shape[-1])
    input_data = input_data.reshape(n_voxels, t)

    # Skull stripping. The mask is optional on the command line, in which
    # case get_mask() returns None and the data are used unmasked.
    if mask is not None:
        input_data[mask.reshape(n_voxels) == 0] = 0

    # Dual regression (Bayesian Ridge Regression)
    reg = BayesianRidge(compute_score=True)
    # Stage 1: selected regressor column ~ data -> one weight per time point.
    reg.fit(input_data, ics[:, comp_n])
    time_course = reg.coef_
    # Stage 2: time course ~ transposed data -> one weight per voxel.
    reg.fit(input_data.T, time_course)
    nib.Nifti1Image(norm(reg.coef_).reshape(x, y, z),
                    get_affine(args.input)).to_filename(args.output)


def main(argv=None):
    """Parse command-line arguments and run the selected sub-command.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read the process command line.
    """
    parser = argparse.ArgumentParser(
        prog='pynit',
        description="Collection of PyNIT processing and analyzing functions")
    parser.add_argument("-v", "--version", action='version',
                        version='%(prog)s 0.0.1')

    subparsers = parser.add_subparsers(title='Sub-commands',
                                       description='Something description',
                                       help='description',
                                       dest='function',
                                       metavar='command')

    dualreg = subparsers.add_parser("dualreg", help='DualRegression')
    # The handler always reads the input and regressor and writes the output,
    # so let argparse reject a call that omits them instead of failing later
    # with None passed as a file path. The mask stays optional.
    dualreg.add_argument("-i", "--input", help="input file", type=str,
                         required=True)
    dualreg.add_argument("-o", "--output", help="output file", type=str,
                         required=True)
    dualreg.add_argument("-m", "--mask", help="mask file", type=str)
    dualreg.add_argument("-r", "--regressor", help="regressor file", type=str,
                         required=True)
    dualreg.add_argument("-d", "--index", help="index of regressor", type=int,
                         default=0)
    dualreg.set_defaults(func=dualRegression)

    args = parser.parse_args(argv)
    # Each sub-parser registers its handler as ``func``; the attribute is
    # absent when no sub-command was given on the command line.
    handler = getattr(args, 'func', None)
    if handler is None:
        # Show usage rather than exiting silently with nothing done.
        parser.print_help()
        return
    handler(args)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()