
import argparse
import numpy as np
import sys, os
sys.path.append('/groups/ahrens/home/ruttenv/code/zfish/')
sys.path.append('/groups/ahrens/home/ruttenv/python_packages/nmfx/')
import nmfx
from nmfx.parameters import Parameters
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')
from time import time
from nmfx import get_subfolders
from datetime import datetime
from zfish.image import imclass as im
import re

### command-line interface
# All options are registered on the module-level `parser`; the parsed
# namespace is exposed as the module-level `args` read by the rest of the
# script.
parser = argparse.ArgumentParser()

# --- data selection ---------------------------------------------------------
# `--folder_path` is required: the script calls `.split` on it right after
# parsing, so a missing value would otherwise surface as an AttributeError on
# None instead of a usage message.
parser.add_argument(
    '--folder_path',
    '-fp',
    type=str,
    required=True,
    help="path to folder")

# `--experiment_number` is required: the experiment sub-folder is chosen by
# looking for str(experiment_number) inside the sub-folder names, which cannot
# succeed meaningfully for None.
parser.add_argument(
    '--experiment_number',
    '-enum',
    type=int,
    required=True,
    help="experiment number")

parser.add_argument(
    '--channel_number',
    '-chnum',
    type=int,
    help="channel number - 0 or 1")

# --- factorisation settings -------------------------------------------------
parser.add_argument(
    '--k',
    '-k',
    type=int,
    default=20,
    help="Number of components to extract")

parser.add_argument(
    '--l1_W',
    '-l1',
    type=float,
    default=0,
    help="l1 loss to be applied to weight matrix W")

parser.add_argument(
    '--max_iter',
    '-max_i',
    type=int,
    default=2000,
    help="maximum number of iterations")

parser.add_argument(
    '--min_loss',
    '-min_l',
    type=float,
    default=1e-3,
    help="minimum loss for convergence")

parser.add_argument(
    '--min_diff',
    '-min_d',
    type=float,
    default=1e-5,
    help="minimum difference")

parser.add_argument(
    '--batch_size',
    '-bs',
    type=int,
    default=50,
    help="batch_size")

parser.add_argument(
    '--step_size',
    '-ss',
    type=float,
    default=1e-2,
    help="step size")

# --- progress reporting / output --------------------------------------------
parser.add_argument(
    '--print_iter',
    '-pi',
    type=int,
    default=100,
    help="print iter")

parser.add_argument(
    '--save_iter',
    '-si',
    type=int,
    default=200,
    help="save iter")

# When omitted (None) the script falls back to a folder derived from the
# experiment's own directory layout.
parser.add_argument(
    '--save_path',
    '-sp',
    type=str,
    default=None,
    help="path where to save results")

# Parses sys.argv; argparse prints a usage message and exits on invalid or
# missing required options.
args = parser.parse_args()


### initialize parameters
# Copy the optimiser settings from the command line onto a fresh nmfx
# Parameters object, one attribute per option of the same name.
parameters = Parameters()
for _name in ('l1_W', 'max_iter', 'min_loss',
              'batch_size', 'step_size', 'min_diff'):
    setattr(parameters, _name, getattr(args, _name))

# Reporting / checkpoint intervals; these are handed to nmfx.nmf further down
# rather than stored on `parameters`.
print_iter, save_iter = args.print_iter, args.save_iter

### initialize results dictionary
# NOTE(review): vars() returns the instance's own __dict__, not a copy, so
# every key written to `results` from here on is also visible as an attribute
# of `parameters`. Confirm whether nmfx relies on this before replacing it
# with a copy.
results = vars(parameters)
results.update(
    k=args.k,
    folder_path=args.folder_path,
    experiment_number=args.experiment_number,
    channel_number=args.channel_number,
)


### locate data
# Candidate experiment names are the keys returned by get_subfolders for the
# top-level folder, minus the first one.
# NOTE(review): presumably the first key is not an experiment folder - confirm
# against get_subfolders.
exps = list(get_subfolders(args.folder_path).keys())[1:]

# Number embedded in the folder path between the first '_f' and the following
# '_'; it is only used for the log line below.
fnum = int(args.folder_path.split('_f')[1].split('_')[0])

# Select the experiment whose name contains the requested number. Matching is
# by substring, so e.g. 1 also matches a name containing 10; as before, the
# last matching candidate is the one used.
matches = [exp_ for exp_ in exps if str(args.experiment_number) in exp_]
if not matches:
    # Previously this fell through and crashed with a NameError on `exp`.
    raise FileNotFoundError(
        f'no experiment matching {args.experiment_number} found in '
        f'{args.folder_path} (candidates: {exps})')
if len(matches) > 1:
    print(f'multiple experiments match {args.experiment_number}: '
          f'{matches}; using {matches[-1]}')
exp = matches[-1]


print('fnum: {}'.format(fnum))
print('processing ' + exp)
# String concatenation: args.folder_path is expected to end with '/'.
folder_name = args.folder_path + exp + '/'
dirs = get_subfolders(folder_name)
dpath = dirs['ephys'] + 'valid_cell_inds.npy'

# Base directory under which this run's result folder is created.
if args.save_path is None:
    save_path = dirs['factors']
else:
    # The original read the non-existent `args.data_path` here (AttributeError
    # whenever --save_path was given) and appended a '.npy' file name to what
    # is used as a directory prefix below. The user-supplied path is now used
    # directly as the base directory.
    save_path = args.save_path


# Run folder / file stem: timestamp plus the main hyper-parameters.
date = datetime.today().strftime('%y%m%d_%H%M%S')
experiment_name = date + f'_results_k_{args.k}_l1_loss_weight_{args.l1_W}'
# Replace every run of '.', '?' or '!' (and any spaces after it) with 'p' so
# the name contains no dots, e.g. '0.1' becomes '0p1'.
experiment_name = re.sub(r'([.?!]+) *', r'p', experiment_name)

# os.path.join tolerates a base directory with or without a trailing '/'.
subfolder_path = os.path.join(save_path, experiment_name)
# makedirs also creates missing parent directories; it still raises
# FileExistsError if the run folder already exists, so an earlier run is never
# silently overwritten.
os.makedirs(subfolder_path)
results_path = os.path.join(subfolder_path, experiment_name)


# Load the object stored in valid_cell_inds.npy; `.item()` unwraps it from the
# array np.load returns, and it is then indexed by string keys.
val_dict = np.load(dpath, allow_pickle=True).item()
valid_inds = val_dict['valid_inds']
print('loaded valid cells')

# Timepoint censoring information is optional: it is read only when the
# 'valid_tpts' key is present, otherwise all three variables are None.
keys = list(val_dict.keys())
if 'valid_tpts' not in keys:
    valid_timepoints = nonvalid_timepoints = full_tpts_nan = None
else:
    print('censored timepoints found')
    valid_timepoints, nonvalid_timepoints, full_tpts_nan = (
        val_dict[_key] for _key in ('valid_tpts', 'bad_tpts', 'full_tpts_nan')
    )



### load data
# t0 marks the start of loading; it is also the reference point for the total
# execution time reported at the end of the script.
t0 = time()
mk = im.Mika(dirs=dirs, ch=args.channel_number)
dims = mk.dims
print('loading data...')
# Without censoring information every one of the mk.t timepoints is used.
valid_timepoints = (
    np.arange(mk.t) if valid_timepoints is None else valid_timepoints
)

# Per-channel cache file of the reduced data matrix.
fpath = f"{dirs['ephys']}X{args.channel_number}.npy"

# NOTE(review): a cache file that already exists is loaded as-is; nothing here
# checks that it was built with the current valid_inds / valid_timepoints.
cache_missing = not os.path.isfile(fpath)
if cache_missing:
    mk.load_celldata()
    # Select timepoints along axis 1, then valid_inds along axis 0, and
    # transpose so that rows are timepoints and columns are the selected
    # entries.
    X = mk.df[:, valid_timepoints][valid_inds].T
    print('data loaded')
    np.save(fpath, X)
    print('saved reduced data for the future')
else:
    print('found existing data')
    X = np.load(fpath)

# Loading time in minutes, rounded to three decimals.
t1 = time()
tdiff = np.round((t1 - t0) / 60, 3)
print(f'loaded data in {tdiff} mins')

t, d = X.shape
print('data loaded')
print(f't: {t}, d: {d}')

# Record the data dimensions and the index sets used alongside the run
# settings.
results.update(
    t=t,
    d=d,
    valid_timepoints=valid_timepoints,
    full_tpts_nan=full_tpts_nan,
    valid_inds=valid_inds,
)


### ensure data is positive
# If the global minimum is negative, shift the whole matrix so that its
# minimum becomes zero. The offset is stored in results['X_min'] below so the
# shift can be undone later.
X_min = np.min(X)
if X_min < 0:
    print(f'xmin: {X_min}')
    X -= X_min  # in-place shift: force to non-negative values

### run nmfx
# `results_path` is passed as the save path for nmfx.nmf (together with
# save_iter) and is reused below for the final results file.
H, W, log = nmfx.nmf(
    X, args.k, parameters,
    print_iter=print_iter, save_iter=save_iter,
    save_path=results_path, initial_values=None)
t1 = time()
# Minutes elapsed since t0 (start of data loading), rounded to three decimals
# like the loading time. The original rounded the elapsed *seconds* to an
# integer and then divided by 60, which left the quotient unrounded.
tdiff = np.round((t1 - t0) / 60, 3)

print(f'total execution time: {tdiff}')

# Add the factorisation outputs to the run record.
results['H'] = H
results['W'] = W
results['loss'] = log.total_loss
results['execution_time'] = tdiff
results['valid_inds'] = valid_inds
results['X_min'] = X_min

# np.save stores the dict as a pickled object array and appends '.npy' to the
# file name; read it back with np.load(..., allow_pickle=True).item().
np.save(results_path, results)
print('finished - results saved!')

