import nmfx
import argparse
import numpy as np
from nmfx.parameters import Parameters
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

def _iteration_count(text):
    """Convert a command-line iteration count to ``int``.

    Plain integers are accepted, as is integer-valued float notation such
    as ``2e4`` (this option used to be parsed with ``float``, so such input
    keeps working). The result is always an ``int``, matching the type of
    the option's default.

    Raises:
        argparse.ArgumentTypeError: if ``text`` is not a number or is not
            integer-valued (this also rejects ``nan`` and ``inf``).
    """
    try:
        # Fast path; also avoids float rounding for very large integers.
        return int(text)
    except ValueError:
        pass
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"invalid iteration count: {text!r}") from None
    if not value.is_integer():
        raise argparse.ArgumentTypeError(
            f"iteration count must be a whole number, got {text!r}")
    return int(value)


# Command-line interface: every option below ends up as an attribute of
# ``args`` (e.g. ``args.data_path``, ``args.k``).
parser = argparse.ArgumentParser()

# The input array is mandatory: nothing further down can run without it, so
# let argparse report its absence instead of failing later on a None path.
parser.add_argument(
    '--data_path',
    '-dp',
    type=str,
    required=True,
    help="path to numpy array containing data (txn)")

parser.add_argument(
    '--k',
    '-k',
    type=int,
    default=20,
    help="Number of components to extract")

parser.add_argument(
    '--l1_W',
    '-l1',
    type=float,
    default=0,
    help="l1 loss to be applied to weight matrix W")

# Parsed to int (not float) so a user-supplied value has the same type as
# the default.
parser.add_argument(
    '--max_iter',
    '-max_i',
    type=_iteration_count,
    default=20000,
    help="maximum number of iterations")

parser.add_argument(
    '--min_loss',
    '-min_l',
    type=float,
    default=1e-2,
    help="minimum loss for convergence")

parser.add_argument(
    '--batch_size',
    '-bs',
    type=int,
    default=200,
    help="batch_size")

parser.add_argument(
    '--step_size',
    '-ss',
    type=float,
    default=1e-2,
    help="step size")

# When omitted, the save path is derived from the data path further down.
parser.add_argument(
    '--save_path',
    '-sp',
    type=str,
    default=None,
    help="path where to save results")

args = parser.parse_args()

# Default output location: a file named nmfx_results.npy in the same
# directory as the input data.
if args.save_path is None:
    args.save_path = str(Path(args.data_path).parent / 'nmfx_results.npy')

### initialize parameters
# Copy the solver-related command-line options onto the nmfx Parameters
# object that is handed to nmfx.nmf below.
parameters = Parameters()
parameters.l1_W = args.l1_W
parameters.max_iter = args.max_iter
parameters.min_loss = args.min_loss
parameters.batch_size = args.batch_size
parameters.step_size = args.step_size

### initialize results dictionary
# vars() returns the object's live __dict__; take a copy so that the keys
# added to `results` below (k, data_path, t, d, H, W, loss) are not also
# set as attributes on `parameters`.
results = dict(vars(parameters))
results['k'] = args.k
results['data_path'] = args.data_path

### load data
X = np.load(args.data_path)
if X.ndim != 2:
    # Fail with a readable message rather than a tuple-unpacking error.
    raise ValueError(
        f'expected a 2-D array in {args.data_path}, got shape {X.shape}')
t, d = X.shape
print('data loaded')
print(f't: {t}, d: {d}')

results['t'] = t
results['d'] = d


### normalize data
# Shift the whole array in place so that its global minimum becomes 0;
# arrays that are already non-negative are left untouched.
X_min = np.min(X)
if X_min < 0:
    print(f'xmin: {X_min}')
    X -= X_min  # force to positive values

### run nmfx
H, W, log = nmfx.nmf(X, args.k, parameters)

results['H'] = H
results['W'] = W
results['loss'] = log.total_loss

# `results` is a dict, so np.save stores it as a pickled object array: read
# it back with np.load(path, allow_pickle=True).item(). np.save appends
# ".npy" to the file name if it does not already end with it.
np.save(args.save_path, results)

