#!/usr/bin/env python
import argparse
import json
from pathlib import Path
import numpy as np
import PBCT
from PBCT.classes import DEFAULTS


def parse_args():
    """Build the command-line parser and parse `sys.argv`.

    Returns:
        argparse.Namespace: one attribute per option declared below. Options
        that were not provided hold their declared default (None when no
        default is declared).
    """
    arg_parser = argparse.ArgumentParser(
        description=(
            "Fit a PBCT model to data or use a trained model to predict new"
            " results. Input files and options may be provided either with comm"
            "and-line options or by a json config file (see --config option)."
        ),
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # The action is optional on the command line (nargs='?') so that it can be
    # supplied by the json config file instead.
    arg_parser.add_argument(
        'action', choices=('fit', 'predict', 'train_test', 'xval'), nargs='?',
        help=(
            'fit: Use input data to train a PBCT. '
            'predict: Predict interaction between input instances. '
            'train_test: Split data between the 4 train/test sets, train and te'
            'st a PBCT. '
            'xval: run a 2D k-fold cross validation with the given data.'
        ))

    # The example uses the "action" key, since config keys are merged into the
    # same option dict that main() reads its action from.
    arg_parser.add_argument(
        '--config',
        help=('Load options from json file. File example: {'
              '\n\t"path_model": "/path/to/save/model.json",'
              '\n\t"action": "fit",'
              '\n\t"XX": ["/path/to/X1.csv", "/path/to/X2.csv"],'
              '\n\t"Y": "/path/to/Y.csv",'
              #'\n\t"XX_names": ["/path/to/X1_names.txt",  "/path/to/X2_names.txt"],'
              #'\n\t"XX_col_names": ["/path/to/X1_col_names.txt", "/path/to/X2_col_names.txt"],'
              '\n}.'
              ' Multiple dicts in a list will caus'
              'e this script to run multiple times.'))

    arg_parser.add_argument(
        '--XX', nargs='+',
        help=('Paths to .csv files containing rows of numerical attributes'
              ' for each axis\' instance.'))
    arg_parser.add_argument(
        '--XX_names', nargs='+',
        help=('Paths to files containing string identifiers for each instance'
              ' for each axis, being one file for each axis.'))
    arg_parser.add_argument(
        '--XX_col_names', nargs='+',
        help=('Paths to files containing string identifiers for each attribute'
              ' column, being one file for each axis.'))
    arg_parser.add_argument(
        '--Y',
        help=('If fitting the model to data, it represents the pat'
              'h to a .csv file containing the known interaction matrix be'
              'tween rows and columns data.'
              ' If predicting, Y is the destination path for the pr'
              'edicted values, formatted as an interaction matrix in the s'
              'ame way described.'))
    arg_parser.add_argument(
        '--path_model', default=DEFAULTS['path_model'],
        help=('When fitting: path to the location where the model will be '
              'saved. When predicting: the saved model to be used.'))
    # type=int: without it, values given on the command line would reach the
    # model as strings while the defaults would not.
    arg_parser.add_argument(
        '--max_depth', type=int, default=DEFAULTS['max_depth'],
        help=('Maximum PBCT depth allowed. -1 will disable this stopping cr'
              'iterion.'))
    arg_parser.add_argument(
        '--min_samples_leaf', type=int, default=DEFAULTS['min_samples_leaf'],
        help=('Minimum number of sample pairs in the training set required'
              ' to arrive at each leaf.'))
    # arg_parser.add_argument(
    #     '--simple_mean', action='store_true',
    #     help=('If provided, the prototype function will always return the '
    #           'mean value over the entire sub interaction matrix of the leaf'
    #           ', not considering possible known instances.'))
    arg_parser.add_argument(
        '--verbose', '-v', action='store_true',
        help='Show more detailed output')
    arg_parser.add_argument(
        '--outdir', default='PBCT_results',
        help='Where to save results.')
    arg_parser.add_argument(
        '--k', '-k', type=int, nargs='+', default=[3],
        help='Number of folds for cross-validation.')
    arg_parser.add_argument(
        '--diag', action='store_true',
        help=('Use independent TrTc sets for cross-validation, i.e. with no ove'
              'rlapping rows or columns.'))
    # Literal percent signs are doubled because argparse %-formats help text.
    arg_parser.add_argument(
        '--test_size', type=float, nargs='+',
        help=('If between 0.0 and 1.0, represents the proportion of the dataset '
              'to include in the TrTc split for each axis, e.g.: .3 .5 means 30%% '
              'of the rows and 50%% of the columns will be used as the TrTc set.'
              ' If >= 1, represents the absolute number of test samples in each a'
              'xis. If None, the values are set to the complements of train_size'
              '. If a single value v is given, it will be interpreted as (v, v).'
              ' If train_size is also None, it will be set to 0.25.'))
    arg_parser.add_argument(
        '--train_size', type=float, nargs='+',
        help='Same as test_size, but refers to the LrLc set dimensions.')
    arg_parser.add_argument(
        '--njobs', type=int,
        help='How many processes to spawn when cross-validating.')
    arg_parser.add_argument(
        '--random_state', type=int,
        help='Random seed to use.')

    # TODO:
    # arg_parser.add_argument(
    #     '--visualize', default=DEFAULTS['path_rendering'],
    #     help=('If provided, path to where a visualization of the trained t'
    #           'ree will be saved.'))

    # TODO:
    #arg_parser.add_argument(
    #    '--fromdir',
    #    help=('Read data from directory instead. In such case, filenames must be:'
    #          '\tX1, X2, Y, X1_names, X2_names, X1_col_names and X2_col_names,\n'
    #          'with any file extension. *_names files are optional.'))

    return arg_parser.parse_args()


def _unwrap_singleton(value):
    """Return the sole element of a one-item list, else `value` unchanged."""
    if isinstance(value, list) and len(value) == 1:
        return value[0]
    return value


def _save_predictions(pred, dir_pred):
    """Create `dir_pred` and write each entry of `pred` to '<key>.csv'."""
    dir_pred.mkdir()
    for LT, data in pred.items():
        np.savetxt(dir_pred/(LT + '.csv'), data, delimiter=',')


def main(**kwargs):
    """CLI for using PBCTs.

    Command-line utility for training a PBCT or using a trained model to
    predict new interactions. See `parse_args()` or use --help for the
    parameters' description.

    Args:
        **kwargs: the options declared in `parse_args()`, keyed by option
            name. When `config` names a json file, its contents override
            the values given here; a json list of dicts runs this function
            once per dict.

    Raises:
        TypeError: if the json config is neither a dict nor a list.
        ValueError: if no valid `action` was given, either directly or
            through the config file.
    """
    if kwargs['config'] is not None:
        # If config file is a single dict, load its options and proceed.
        # If it's a list of dicts, run this function for each of them.
        with open(kwargs['config']) as config_file:
            config = json.load(config_file)
        if isinstance(config, list):
            # Ensure we are not caught in the top conditional again.
            kwargs['config'] = None
            for jsonkwargs in config:
                main(**(kwargs | jsonkwargs))
            return
        if not isinstance(config, dict):
            raise TypeError(
                'Config file must contain a json object or a list of json '
                f'objects, got {type(config).__name__}.'
            )
        kwargs.update(config)

    # Seed only after the config was merged, so that a random_state given in
    # the config file is honoured.
    np.random.seed(kwargs['random_state'])

    # Fail before loading any data instead of silently doing nothing.
    action = kwargs['action']
    if action not in ('fit', 'predict', 'train_test', 'xval'):
        raise ValueError(
            f'Unknown or missing action: {action!r}. Expected one of '
            "'fit', 'predict', 'train_test' or 'xval'."
        )

    print('Loading data...')
    XX = [np.loadtxt(p, delimiter=',') for p in kwargs['XX']]
    XX_names, XX_col_names = None, None
    # if kwargs['XX_names']:
    #     XX_names = [np.loadtxt(p) for p in kwargs['XX_names']]
    if kwargs['XX_col_names']:
        # Column names are string identifiers; loadtxt's default float dtype
        # would raise on them.
        XX_col_names = [
            np.loadtxt(p, dtype=str) for p in kwargs['XX_col_names']
        ]

    if action == 'predict':
        print('Loading model...')
        Tree = PBCT.load_model(kwargs['path_model'])
        print('Predicting values...')
        predictions = Tree.predict(
            XX,
            # simple_mean=kwargs['simple_mean'],
            verbose=kwargs['verbose'],
        )
        print('Saving...')
        # NOTE(review): fmt='%d' writes predictions as integers; confirm that
        # dropping any fractional part is intended.
        np.savetxt(kwargs['Y'], predictions, delimiter=',', fmt='%d')
        print('Done.')
        return

    outdir = Path(kwargs['outdir'])
    Y = np.loadtxt(kwargs['Y'], delimiter=',')
    Tree = PBCT.PBCT(
        min_samples_leaf=kwargs['min_samples_leaf'],
        max_depth=kwargs['max_depth'],
        verbose=kwargs['verbose'],
    )

    if action == 'fit':
        print('Training PBCT...')
        Tree.fit(XX, Y, X_names=XX_col_names)  # FIXME: Confusing variable names.
        print('Saving model...')
        Tree.save(kwargs['path_model'])
        print('Done.')

    elif action == 'train_test':
        # mkdir() raises if outdir already exists, so previous results are
        # never overwritten.
        outdir.mkdir()

        # A single value v is passed on as a scalar rather than as [v].
        test_size = _unwrap_singleton(kwargs['test_size'])
        train_size = _unwrap_singleton(kwargs['train_size'])

        split, pred = PBCT.train_test.split_fit_test(
            XX, Y, Tree, test_size=test_size,
            train_size=train_size,
        )
        Tree.save(outdir/'model.dict.pickle.gz')

        PBCT.train_test.save_split(split, outdir/'data')
        _save_predictions(pred, outdir/'pred')

    elif action == 'xval':
        # NOTE: models are not being saved.
        outdir.mkdir()
        k = _unwrap_singleton(kwargs['k'])
        splits, _, preds = PBCT.train_test.cross_validate_2D(
            XX, Y, Tree, k=k, diag=kwargs['diag'], njobs=kwargs['njobs'],
        )
        dir_folds = outdir/'folds'
        dir_preds = outdir/'preds'
        dir_folds.mkdir()
        dir_preds.mkdir()

        # Folds are numbered from 1 in the output directory names.
        for i, split in enumerate(splits):
            PBCT.train_test.save_split(split, dir_folds/f'fold{i+1}')
        for i, pred in enumerate(preds):
            _save_predictions(pred, dir_preds/f'fold{i+1}')


if __name__ == '__main__':
    # Script entry point: hand every parsed command-line option to main() as
    # a keyword argument.
    main(**vars(parse_args()))
