#!/usr/bin/env python
import argparse
import datetime
import json

def write_error(err_file, message):
    """Append a timestamped ``message`` to ``err_file`` and log it as an error.

    The appended line has the form ``<datetime.now()> | <message>``. The file
    is written before ``optim_esm_tools`` is imported, so the entry is on disk
    even if obtaining the logger fails afterwards.

    Args:
        err_file: path of the file that is opened in append mode.
        message: text to record; also passed to the logger's ``error``.
    """
    # Build the full line first so nothing is opened if formatting fails.
    entry = '{} | {}\n'.format(datetime.datetime.now(), message)
    with open(err_file, 'a') as handle:
        handle.write(entry)

    # Deferred import: the project package is only needed for logging.
    import optim_esm_tools
    optim_esm_tools.config.get_logger().error(message)


def parse_args(argv=None):
    """Parse the command-line options of this plotting script.

    Args:
        argv: optional list of argument strings to parse instead of
            ``sys.argv[1:]``. Leaving it at ``None`` keeps the normal
            command-line behaviour; passing a list lets the parser be driven
            programmatically.

    Returns:
        argparse.Namespace with the attributes ``path``, ``methods``,
        ``variable``, ``save_in``, ``err_file``, ``show``,
        ``profile_memory``, ``extra_opt`` and ``read_ds_kw``.
    """
    parser = argparse.ArgumentParser(description='make plots')
    parser.add_argument(
        '--path',
        type=str,
        help='path of the data set that is read with read_ds',
    )
    parser.add_argument(
        '--methods',
        nargs='*',
        default=['Percentiles', 'ProductPercentiles', 'LocalHistory'],
        help='class names in optim_esm_tools.analyze.region_finding to run',
    )
    parser.add_argument(
        '--variable',
        type=str,
        help='variable forwarded to each method',
    )
    parser.add_argument(
        '--save_in',
        type=str,
        help='forwarded to each method as the save_in entry of save_kw',
    )
    parser.add_argument(
        '--err_file',
        type=str,
        default='errors.txt',
        help='file to which error messages are appended',
    )
    parser.add_argument(
        '--show',
        action='store_true',
        help='set the show attribute of each method to True',
    )
    parser.add_argument(
        '--profile_memory',
        action='store_true',
        help='run under memory_profiler and report RAM usage and run time',
    )
    # The JSON defaults are given as strings on purpose: argparse passes
    # string defaults through ``type``, so they are decoded by json.loads
    # exactly like values given on the command line.
    parser.add_argument(
        '--extra_opt',
        default='{"time_series_joined": false, "scatter_medians": true}',
        type=json.loads,
        help='JSON string forwarded to each method as extra_opt',
    )
    parser.add_argument(
        '--read_ds_kw',
        default='{}',
        type=json.loads,
        help='JSON string with keyword arguments for read_ds',
    )
    return parser.parse_args(argv)

def make_plot(variable, log, path, save_in, show, read_ds_kw, extra_opt, methods):
    """Run the requested region-finding methods on the data set at ``path``.

    Every name in ``methods`` is looked up in
    ``optim_esm_tools.analyze.region_finding``. The data set is read once with
    ``read_ds`` and each method is instantiated on it. A method's
    ``workflow()`` only runs when the first element of its filtered masks is
    not an empty list; methods that cannot be evaluated are logged and
    skipped.

    Args:
        variable: forwarded to each method as ``variable``.
        log: logger used for progress and skip messages.
        path: passed to ``read_ds``.
        save_in: stored under ``save_in`` in the ``save_kw`` given to each
            method.
        show: assigned to the ``show`` attribute of each method before its
            workflow runs.
        read_ds_kw: keyword arguments for ``read_ds``; ``None`` means none.
        extra_opt: forwarded to each method as ``extra_opt``.
        methods: iterable of class names to fetch from the region_finding
            module.

    Raises:
        ValueError: if a name in ``methods`` does not exist in the module, or
            re-raised from a method when its message does not mention
            'moving average window'.
    """
    save_kw = {
        'save_in': save_in,
        'sub_dir': None,
        'file_types': ('png',),
        'skip': False,
    }

    from optim_esm_tools.utils import setup_plt, print_versions

    print_versions()
    setup_plt()

    import optim_esm_tools.analyze.region_finding as region_finding
    from optim_esm_tools.analyze.pre_process import NoDataInTimeRangeError

    # Unknown names are replaced by the marker string so they show up in the
    # error message below.
    finder_classes = [
        getattr(region_finding, name, 'not found') for name in methods
    ]
    if any(candidate == 'not found' for candidate in finder_classes):
        raise ValueError(f'One or more non-existing methods {methods} ({finder_classes})')

    from optim_esm_tools.analyze.cmip_handler import read_ds

    log.info('Read path')
    if read_ds_kw is None:
        read_ds_kw = {}
    data_set = read_ds(path, **read_ds_kw)
    shared_kw = {
        'data_set': data_set,
        'variable': variable,
        'path': None,
        'read_ds_kw': read_ds_kw,
    }

    def _has_masks(finder):
        """Return True if ``finder`` produced masks, else log why and return False."""
        label = finder.__class__.__name__
        try:
            filtered = finder.filter_masks_and_clusters(finder.get_masks())
            if filtered[0] == []:
                log.warning(f'No results for {label} - masks are empty')
                return False
        except (NoDataInTimeRangeError, NotImplementedError, RuntimeError):
            log.error(f'No match, not making {label}')
            return False
        except ValueError as err:
            # Only the known "too short" failure is tolerated; any other
            # ValueError is propagated to the caller.
            if 'moving average window' in str(err):
                log.error(f'ValueError for {label} as the dataset is too short. Probably in LocalHistory and unharmful.')
                return False
            raise err
        return True

    for finder_cls in finder_classes:
        finder = finder_cls(**shared_kw, save_kw=save_kw, extra_opt=extra_opt)
        if not _has_masks(finder):
            continue
        finder.show = show
        log.warning(f'{finder.__class__.__name__}')
        finder.workflow()

    log.warning('All done')


def get_log(args):
    """Return the logger named after the ``source_id`` found for ``args.path``.

    ``folder_to_dict`` is applied to ``args.path`` and the ``'source_id'``
    entry of its result is handed to ``get_logger``.
    """
    from optim_esm_tools.analyze.find_matches import folder_to_dict
    from optim_esm_tools.config import get_logger

    source_id = folder_to_dict(args.path)['source_id']
    return get_logger(source_id)

def main(log, args):
    """Call ``make_plot`` with the parsed arguments and record any failure.

    Any ``Exception`` raised while preparing or running ``make_plot`` is
    passed to ``write_error`` (prefixed with ``args.path``) and then
    re-raised.

    Args:
        log: logger forwarded to ``make_plot``.
        args: namespace providing ``variable``, ``path``, ``save_in``,
            ``show``, ``read_ds_kw``, ``extra_opt``, ``methods`` and
            ``err_file``.
    """
    try:
        # Built inside the try so that a failing attribute lookup is
        # recorded in the error file as well.
        plot_kw = {
            'variable': args.variable,
            'log': log,
            'path': args.path,
            'save_in': args.save_in,
            'show': args.show,
            'read_ds_kw': args.read_ds_kw,
            'extra_opt': args.extra_opt,
            'methods': args.methods,
        }
        make_plot(**plot_kw)
    except Exception as err:
        write_error(args.err_file, f'{args.path} | {err}')
        raise err

if __name__ == '__main__':
    # Script entry point: run main() once, optionally under memory_profiler
    # so that RAM usage and wall-clock time are reported at the end.
    args = parse_args()
    if not args.profile_memory:
        log = get_log(args)
        main(log, args)
        log.info('Done, bye')
    else:
        # These modules are only needed for profiling, so they are imported
        # in this branch only.
        from memory_profiler import memory_usage
        import time
        import numpy as np

        log = get_log(args)
        start = time.time()
        # memory_usage runs main(log, args) and returns the memory readings
        # it collected while doing so.
        samples = memory_usage(proc=(main, (log, args)), max_iterations=1)
        message = (
            f"RAM usage was: {max(samples):.0f} MB (max), {np.mean(samples): .0f} MB (avg). "
            f'Took {time.time() - start:.1f} s = {(time.time() - start) / 3600:.2f} h '
        )
        log.warning(message)
        print(message)
        log.info('Bye, bye')
