#!/usr/bin/env python

import numpy as np
import galprime as gp

from astropy.table import Table

import time
import argparse

import traceback
import os
import pebble


# Command-line interface. Parsing happens at module level so that ``args`` is a
# module-level global, which ``process_single`` reads (``args.keep_temp``).
parser = argparse.ArgumentParser(description="Run GalPRIME simulation")
parser.add_argument("config_filename", type=str, help="Path to config file")
parser.add_argument("--max_bins", type=int, default=None, help="Maximum number of bins to process")
parser.add_argument("--run_id", type=int, default=gp.get_dt_intlabel(), help="Run ID")
parser.add_argument("--log_level", type=int, default=20, help="Logging level. 10=DEBUG, 20=INFO, 30=WARNING")
parser.add_argument("--verbose", action="store_true", help="Verbose output")
# The flag *keeps* the temporary pickle files; without it they are removed
# after each object has been processed.
parser.add_argument("--keep_temp", action="store_true", help="Keep temporary files instead of deleting them")
parser.add_argument("--save_objs", action="store_true", help="Save output GPrimeSingle files.")
args = parser.parse_args()

# Reference point for the total run time reported at the end of the script.
start_time = time.perf_counter()

# Honour the config path given on the command line instead of a hard-coded
# filename (the positional argument was previously parsed but never used).
config_filename = args.config_filename


def process_single(fn):
    """Load one pickled object from disk, process it and return its output.

    Parameters
    ----------
    fn : str
        Path of the pickle file to load with ``gp.load_object``.

    Returns
    -------
    The value returned by the object's ``condensed_output()`` method.

    Raises
    ------
    RuntimeError
        If the object's ``process()`` call raises. The message is the text of
        the original exception, which is attached as ``__cause__``.

    Notes
    -----
    Reads the module-level ``args`` to decide whether the input file is
    deleted. The input file is only removed after successful processing; if
    ``process()`` raises, the file is left on disk.
    """
    gprime_single = gp.load_object(fn)
    try:
        gprime_single.process()
    except Exception as error:
        # Chain explicitly so the original exception is preserved as the cause.
        raise RuntimeError(f'{error}') from error

    # Optionally keep the fully processed object next to the input file.
    if gprime_single.save_output:
        gp.save_object(gprime_single, fn.replace(".pkl", "_done.pkl"))

    if not args.keep_temp:
        os.remove(fn)

    return gprime_single.condensed_output()


def process_bin(b):
    """Generate, process and collect all simulated objects for one bin.

    For ``config["MODEL"]["N_MODELS"]`` iterations a ``GPrimeSingle`` object is
    built from KDE-sampled parameters plus a randomly chosen background and
    PSF cutout, and pickled to the temporary directory. The pickles are then
    processed in a ``pebble`` process pool by ``process_single``, and every
    result containing exactly three entries under ``"ISOLISTS"`` is passed on
    to ``gp.handle_output``.

    Parameters
    ----------
    b
        Bin object; its ``objects`` attribute and ``bin_id()`` method are used.

    Notes
    -----
    Relies on the module-level globals ``config``, ``bgs``, ``psfs``,
    ``model``, ``mag_kde``, ``outfiles``, ``run_id``, ``logger`` and ``args``
    that are set up in the ``__main__`` section of this script.
    """
    n_workers = config["NCORES"]
    n_objects = config["MODEL"]["N_MODELS"]

    # One random background and one random PSF cutout index per object.
    bg_indices = np.random.randint(0, len(bgs.cutouts), n_objects)
    psf_indices = np.random.randint(0, len(psfs.cutouts), n_objects)

    keys, kde = gp.setup_kde(model(), config, b.objects)

    # Build each object and pickle it to disk; only the file paths are handed
    # to the worker processes.
    pending_files = []
    for idx, (bg_idx, psf_idx) in enumerate(zip(bg_indices, psf_indices)):
        temp_path = f'{outfiles["TEMP"]}{run_id}_{b.bin_id()}_{idx}.pkl'

        bg_cutout = bgs.cutouts[bg_idx]
        psf_cutout = psfs.cutouts[psf_idx]

        params = gp.update_required(gp.sample_kde(config, keys, kde), config)

        # If a magnitude KDE was loaded, its draw replaces the "MAG" entry.
        if mag_kde is not None:
            params["MAG"] = mag_kde.resample(size=1)[0][0]

        metadata = {"ITERATION": idx,
                    "BG_INDEX": bg_idx,
                    "PSF_INDEX": psf_idx}
        single = gp.GPrimeSingle(config, model(), params,
                                 bg=bg_cutout, psf=psf_cutout, logger=logger,
                                 save_output=args.save_objs,
                                 metadata=metadata)
        gp.save_object(single, temp_path)
        pending_files.append(temp_path)

    # Schedule one job per pickled object, each with the configured timeout.
    with pebble.ProcessPool(max_workers=n_workers) as pool:
        futures = [
            pool.schedule(process_single, args=(path, ), timeout=config["TIME_LIMIT"])
            for path in pending_files
        ]

    # Collect results; failures of individual objects are logged and skipped.
    good_results = []
    for idx, future in enumerate(futures):
        try:
            result = future.result()
            if len(result["ISOLISTS"]) == 3:
                good_results.append(result)
            else:
                logger.error(f"Not all profiles extracted {idx}")
        except Exception as e:
            logger.error(f'Error processing object {idx}: {e}')

    gp.handle_output(good_results, outfiles, config, bin_id=b.bin_id())

    n_good = len(good_results)
    percent_good = n_good / n_objects * 100
    logger.info(f"Bin {b.bin_id()}: {n_good} of {n_objects} successfully finished ({percent_good} %).")


if __name__ == '__main__':
    # Read the run configuration and obtain the mapping of output locations
    # (indexed below by keys such as "ADDL_DATA").
    config = gp.read_config_file(config_filename)
    outfiles = gp.gen_filestructure(config["DIRS"]["OUTDIR"])

    # Store a copy of the configuration alongside the run's additional data.
    gp.save_object(config, f'{outfiles["ADDL_DATA"]}config_{args.run_id}.pkl')

    run_id = args.run_id
    config["RUN_ID"] = run_id

    log_filename = f'{config["DIRS"]["OUTDIR"]}output_{run_id}.log'
    logger = gp.setup_logging(run_id, args.log_level, log_filename=log_filename)

    logger.info(f"Starting run ID:{run_id}, GalPRIME Version: {gp.__version__}")

    # Load in backgrounds, PSFS, and object catalogue
    bgs = gp.Cutouts.from_file(f'{config["FILE_DIR"]}{config["FILES"]["BACKGROUNDS"]}', logger=logger)
    psfs = gp.Cutouts.from_file(f'{config["FILE_DIR"]}{config["FILES"]["PSFS"]}', logger=logger)
    psfs.get_ra_dec(ra_key=config["PSFS"]["PSF_RA"], dec_key=config["PSFS"]["PSF_DEC"])

    # Optional magnitude catalogue: when configured, a KDE built from its
    # magnitude column is used by process_bin to draw the "MAG" parameter.
    mag_kde = None
    if config["FILES"]["MAG_CATALOGUE"] is not None:
        mag_table = Table.read(f'{config["FILE_DIR"]}{config["FILES"]["MAG_CATALOGUE"]}')
        mags = mag_table[config["KEYS"]["MAG"]]
        mag_kde = gp.object_kde([mags])
        logger.info(f'Loaded magnitude kde from {len(mag_table)} entries in {config["FILES"]["MAG_CATALOGUE"]}')

    table = Table.read(f'{config["FILE_DIR"]}{config["FILES"]["CATALOGUE"]}')
    table = gp.trim_table(table, config)
    logger.info(f'Loaded catalogue with {len(table)} entries')

    print(f"GalPRIME v{gp.__version__} -- Starting run ID:{run_id}")
    print(f'Logfile saved to: {log_filename}')

    binlist = gp.bin_catalogue(table, bin_params=config["BINS"], params=config["KEYS"], logger=logger)
    # Never process more bins than exist, even if --max_bins asks for more.
    max_bins = len(binlist.bins) if args.max_bins is None else min(args.max_bins, len(binlist.bins))

    model = gp.galaxy_models[config["MODEL"]["MODEL_TYPE"]]
    logger.info(f"Using model {model.__name__}")

    # Go through the bins and process them
    for i in range(max_bins):
        b = binlist.bins[i]
        logger.info(f'Processing bin {b.bin_id()} ({i+1} of {max_bins})')
        try:
            process_bin(b)
        except KeyboardInterrupt:
            # Stop the whole run, but still report the elapsed time below.
            logger.error("Run interrupted by user")
            break
        except Exception as e:
            # A failing bin must not abort the run; log it and move on.
            logger.error(f"Critical error processing bin {b.bin_id()}: {e}")
            if args.log_level <= 10:
                # print_exc() writes the traceback itself and returns None, so
                # it must not be wrapped in print().
                traceback.print_exc()
            continue
        logger.info(f"Finished processing bin {b.bin_id()}")

    # Report the elapsed time in minutes from 60 s onwards, otherwise in
    # seconds; unit and value are switched together so they always agree.
    time_elapsed = time.perf_counter() - start_time
    if time_elapsed >= 60:
        time_elapsed /= 60
        unit = "minutes"
    else:
        unit = "seconds"

    logger.info("Finished processing.")
    logger.info(f"Run complete: Time Elapsed: {time_elapsed:.2f} {unit}")

    print(f"Run complete: Time Elapsed: {time_elapsed:.2f} {unit}")
    print("Results saved to: ", config["DIRS"]["OUTDIR"])
