#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright (C) California Institute of Technology 2023
#
# This file is part of pygwb.

from pycondor import Dagman
import pycondor
import logging
from pygwb.workflow import Job, Dagman, config_override, config_write, makedir
from pygwb.workflow import (create_pygwb_pipe_job,
                   create_pygwb_combine_job,
                   create_pygwb_stats_job,
                   create_pygwb_html_job)

import numpy as np
import os
import glob
import argparse
import configparser
import pandas as pd
from pygwb.html import pygwb_html

from gwpy.segments import DataQualityDict, DataQualityFlag

# Program name handed to argparse as ``prog`` in main().
__process_name__ = 'pygwb_create_workflow'
# Version string reported by the -V/--version command-line option.
__version__ = '0.0.1'


def _split_segments(seglist, min_dur, max_dur):
    """Drop segments shorter than ``min_dur`` and split those longer than ``max_dur``.

    Parameters
    ----------
    seglist : iterable
        Items indexable as ``seg[0]`` (start) and ``seg[1]`` (end).
    min_dur, max_dur : int
        Duration bounds, in the same units as the segment edges.

    Returns
    -------
    list of [int, int]
        ``[start, end]`` pairs, in the order of the input segments.
    """
    pruned = []
    for seg in seglist:
        start, end = int(seg[0]), int(seg[1])
        dur = end - start
        if dur < min_dur:
            continue
        if dur > max_dur:
            # Evenly spaced edges; each edge is truncated to an integer.
            n_edges = int(dur / max_dur + 2)
            edges = np.linspace(start, end, n_edges, endpoint=True)
            pruned.extend(
                [int(lo), int(hi)] for lo, hi in zip(edges[:-1], edges[1:])
            )
        else:
            pruned.append([start, end])
    return pruned


def _replace_symlink(target, link_path):
    """Create ``link_path`` -> ``target``, replacing a symlink that already exists.

    Only an existing *symlink* is removed, so regenerating the workflow in the
    same base directory works; a regular file at ``link_path`` is left alone
    and ``os.symlink`` raises ``FileExistsError`` for it.
    """
    if os.path.islink(link_path):
        os.remove(link_path)
    os.symlink(target, link_path)


def make_workflow(args=None):
    """Build the pygwb DAG described by a config file and optionally submit it.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``configfile``, ``basedir``, ``config_overrides`` and
        ``submit``.

    Raises
    ------
    ValueError
        If no segment is left to analyse after applying the science flag,
        the optional category-1 vetoes and the minimum job duration.
    """
    config = configparser.ConfigParser()
    config.read(args.configfile)

    # NOTE(review): everything up to the DAG creation reads the config *before*
    # the command-line overrides are applied further down, so overrides of
    # [general] / [data_quality] keys do not affect segment selection although
    # they do end up in the written config.ini -- confirm this is intended.
    min_dur = config.getint('general', 'min_job_dur', fallback=0)
    max_dur = config.getint('general', 'max_job_dur', fallback=1000000)

    # =====
    # set up segments

    t0 = int(config['general']['t0'])
    tf = int(config['general']['tf'])

    # split() without an argument tolerates repeated/trailing whitespace,
    # which would otherwise produce empty IFO names.
    ifos = config['general']['ifos'].split()

    logging.info('Downloading science flags...')
    science_flag_name = config['data_quality']['science_segment']
    sci_flag = DataQualityDict.query_dqsegdb(
        [ifo + ':' + science_flag_name for ifo in ifos],
        t0, tf
        ).intersection().active

    if config.has_option('data_quality', 'veto_definer'):
        logging.info('Downloading vetoes...')
        vetoes = DataQualityDict.from_veto_definer_file(
            config['data_quality']['veto_definer']
            )
        cat1_vetoes = DataQualityDict({
            v: vetoes[v] for v in vetoes
            if (vetoes[v].category == 1) and (v[:2] in ifos)
        })
        cat1_vetoes.populate()

        # NOTE(review): intersection() only removes times where *all* selected
        # CAT1 flags are active at once; a union may be what is intended here
        # -- confirm before changing.
        seglist = sci_flag - cat1_vetoes.intersection().active
    else:
        seglist = sci_flag

    seglist_pruned = _split_segments(seglist, min_dur, max_dur)
    if not seglist_pruned:
        # Fail before anything is written instead of hitting an IndexError
        # half-way through building the DAG.
        raise ValueError(
            f'No segments to analyse between {t0} and {tf} '
            f'(min_job_dur={min_dur}); cannot build a workflow.'
        )

    # =====

    logging.info('Writing DAG...')
    dagman = Dagman(name='pygwb_dagman',
                    basedir=args.basedir)
    dagman.make_dir()

    config = config_override(config, args.config_overrides)
    config_path = os.path.join(dagman.base_dir, 'config.ini')
    config_write(config, config_path)

    # Loop-invariant values used to build the result file names.
    fref = int(config['pygwb_pipe']['fref'])
    alpha = float(config['pygwb_pipe']['alpha'])
    spectra_prefix = (
        f'point_estimate_sigma_spectra_alpha_{alpha:.1f}_fref_{int(fref)}'
    )

    all_jobs = []
    pipe_jobs = []

    cumu_dir = os.path.join(dagman.result_dir, 'combined_results/')
    makedir(cumu_dir)

    first_start = int(seglist_pruned[0][0])
    last_start = 0
    param_file = None
    for seg in seglist_pruned:
        start = int(seg[0])
        end = int(seg[1])
        dur = int(end - start)
        seg_dir = os.path.join(dagman.result_dir, f'{start}-{dur}/')
        makedir(seg_dir)

        last_start = start

        pipe_job = create_pygwb_pipe_job(
            dagman, config, parents=[],
            output_path=seg_dir,
            t0=start, tf=end
            )
        all_jobs.append(pipe_job)
        pipe_jobs.append(pipe_job)

        param_file = os.path.join(seg_dir, 'parameters_final.ini')
        combine_job = create_pygwb_combine_job(
            dagman, config,
            t0=start, tf=end,
            parents=[pipe_job], output_path=seg_dir,
            data_path=seg_dir, param_file=param_file)
        all_jobs.append(combine_job)

        # Same '<first start>-<last start>' naming as the cumulative files
        # below; for a single segment both values are ``start``.
        csd_file = os.path.join(seg_dir, f'{spectra_prefix}_{start}-{start}.npz')
        dsc_file = os.path.join(seg_dir, f'delta_sigma_cut_{start}-{start}.npz')
        stats_job = create_pygwb_stats_job(
            dagman, config,
            t0=start, tf=end,
            parents=[combine_job], output_path=seg_dir,
            csd_file=csd_file,
            dsc_file=dsc_file,
            param_file=param_file)
        all_jobs.append(stats_job)

        # set up symlinks for pipe results
        for npz in (f'csds_psds_{start}-{end}.npz',
                    f'point_estimate_sigma_{start}-{end}.npz'):
            _replace_symlink(os.path.join(seg_dir, npz),
                             os.path.join(cumu_dir, npz))

    # make cumulative results; ``param_file`` is the one of the last segment
    cumu_combine_job = create_pygwb_combine_job(
        dagman, config,
        t0=t0, tf=tf,
        parents=pipe_jobs, output_path=cumu_dir,
        data_path=cumu_dir, param_file=param_file)
    all_jobs.append(cumu_combine_job)

    csd_file = os.path.join(
        cumu_dir, f'{spectra_prefix}_{first_start}-{last_start}.npz')
    dsc_file = os.path.join(
        cumu_dir, f'delta_sigma_cut_{first_start}-{last_start}.npz')
    logging.debug('Cumulative delta sigma cut file: %s', dsc_file)
    cumu_stats_job = create_pygwb_stats_job(
        dagman, config,
        t0=t0, tf=tf,
        parents=[cumu_combine_job], output_path=cumu_dir,
        csd_file=csd_file,
        dsc_file=dsc_file,
        param_file=param_file)
    all_jobs.append(cumu_stats_job)

    pygwb_html(dagman.base_dir, config=config_path)

    # write out segment list and job file
    seg_dir = os.path.join(dagman.result_dir, 'segment_lists/')
    makedir(seg_dir)
    start_list, end_list = np.array(seglist_pruned).T
    dur_list = end_list - start_list
    df = pd.DataFrame(data={'one': np.ones(len(dur_list), dtype=int),
                            'start': start_list,
                            'end': end_list,
                            'dur': dur_list})
    jobfile_loc = os.path.join(seg_dir, "jobfile.dat")
    df.to_csv(jobfile_loc, sep=" ", header=False, index=False)

    # make html page; it depends on every job created above
    create_pygwb_html_job(
        dagman, config,
        t0=t0, tf=tf, parents=all_jobs,
        output_path=dagman.base_dir, plot_path=dagman.result_dir)

    dagman.build(fancyname=False)

    if args.submit:
        logging.info('Submitting to condor...')
        dagman.submit_dag()

    logging.info('Done')


def main(args=None):
    """Command-line entry point: parse options, set verbosity, build the workflow.

    Parameters
    ----------
    args : list of str, optional
        Argument strings to parse; ``None`` lets argparse read ``sys.argv``.
    """
    logging.basicConfig(level=logging.WARNING,
                        format='%(asctime)s:%(levelname)s : %(message)s')
    root_logger = logging.getLogger()

    cli = argparse.ArgumentParser(prog=__process_name__,
                                  description=__doc__)

    # logging / informational switches
    cli.add_argument('-v', '--verbose', action='store_true',
                     help='increase verbose output')
    cli.add_argument('--debug', action='store_true',
                     help='increase verbose output to debug')
    cli.add_argument('-V', '--version', action='version',
                     version=__version__)

    # workflow options
    cli.add_argument('--basedir', required=True, type=os.path.abspath,
                     help='Build directory for condor logging')
    cli.add_argument('--configfile', required=True, type=str,
                     help='config file with workflow parameters')
    cli.add_argument('--config-overrides', nargs='+', type=str,
                     help='config overrides listed in the form SECTION:KEY:VALUE')
    cli.add_argument('--submit', action='store_true',
                     help='Submit workflow automatically')

    options = cli.parse_args(args=args)

    # --debug takes precedence over --verbose when both are given
    if options.debug:
        root_logger.setLevel(logging.DEBUG)
    elif options.verbose:
        root_logger.setLevel(logging.INFO)

    # hand the parsed options to the workflow builder
    make_workflow(args=options)

# allow this module to be run from the command line
if __name__ == "__main__":
    main(args=None)
