#!/usr/bin/env python3
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2024


"""
Run workflow.
"""

from __future__ import print_function

import argparse
import argcomplete
import base64
# import json
import logging
import pickle
import os
import sys
import time
import traceback

from idds.common.utils import json_dumps, json_loads, setup_logging, decode_base64
# from idds.common.utils import merge_dict
from idds.iworkflow.version import release_version
from idds.iworkflow.workflow import Workflow
from idds.iworkflow.work import Work


setup_logging(__name__, stream=sys.stdout)


def get_context_args(context, original_args, current_job_kwargs):
    """
    Decode the serialized call description of a workflow or work.

    :param context: returned unchanged as the first item of the result.
    :param original_args: text parsed with json_loads into the five items
        (func_name, pre_kwargs, args, kwargs, multi_jobs_kwargs_list). Every
        non-empty item except func_name is a base64-encoded pickle, and
        multi_jobs_kwargs_list is a sequence of such payloads.
    :param current_job_kwargs: text parsed with json_loads into a
        base64-encoded pickle, or the literal "${IN/L}" placeholder.

    :returns: the tuple (context, func_name, pre_kwargs, args, kwargs,
        multi_jobs_kwargs_list, current_job_kwargs).
    """
    def _unpickle(encoded):
        # NOTE(review): pickle.loads can execute arbitrary code; the payloads
        # are presumably produced by the submitting client -- confirm they
        # never come from an untrusted source.
        return pickle.loads(base64.b64decode(encoded))

    func_name = pre_kwargs = args = kwargs = multi_jobs_kwargs_list = None
    if original_args:
        func_name, pre_kwargs, args, kwargs, multi_jobs_kwargs_list = json_loads(original_args)

    # Empty or missing items are left exactly as they were received.
    if args:
        args = _unpickle(args)
    if pre_kwargs:
        pre_kwargs = _unpickle(pre_kwargs)
    if kwargs:
        kwargs = _unpickle(kwargs)
    if multi_jobs_kwargs_list:
        multi_jobs_kwargs_list = list(map(_unpickle, multi_jobs_kwargs_list))

    if current_job_kwargs == "${IN/L}":
        # The placeholder is still in place, so it is returned untouched.
        logging.info("current_job_kwargs == original ${IN/L}, is not set")
    elif current_job_kwargs:
        try:
            current_job_kwargs = json_loads(current_job_kwargs)
            if current_job_kwargs:
                current_job_kwargs = _unpickle(current_job_kwargs)
        except Exception as ex:
            # Best effort: on failure the value decoded so far is returned.
            logging.error("Failed to update kwargs: %s" % ex)

    return context, func_name, pre_kwargs, args, kwargs, multi_jobs_kwargs_list, current_job_kwargs


def run_workflow(context, original_args, current_job_kwargs):
    """
    Decode the arguments, build a Workflow from them and run it.

    :param context: context object; initialize() and setup_source_files() are
        called on it before the Workflow is created.
    :param original_args: serialized call description, decoded by get_context_args.
    :param current_job_kwargs: serialized per-job keyword arguments, decoded by
        get_context_args.

    :returns: always 0; the value returned by Workflow.run() is only logged.
    """
    decoded = get_context_args(context, original_args, current_job_kwargs)
    context, func_name, pre_kwargs, args, kwargs, multi_jobs_kwargs_list, current_job_kwargs = decoded

    # Values wrapped in str() may be sequences, which must not be expanded by "%".
    for template, value in (
        ("context: %s", context),
        ("func_name: %s", func_name),
        ("pre_kwargs: %s", pre_kwargs),
        ("args: %s", str(args)),
        ("kwargs: %s", kwargs),
        ("multi_jobs_kwargs_list: %s", str(multi_jobs_kwargs_list)),
        ("current_job_kwargs: %s", str(current_job_kwargs)),
    ):
        logging.info(template % value)

    context.initialize()
    context.setup_source_files()

    workflow = Workflow(
        func=func_name,
        pre_kwargs=pre_kwargs,
        args=args,
        kwargs=kwargs,
        multi_jobs_kwargs_list=multi_jobs_kwargs_list,
        current_job_kwargs=current_job_kwargs,
        context=context,
    )
    logging.info("workflow: %s" % workflow)

    # The workflow is run inside its own context manager.
    with workflow:
        result = workflow.run()
    logging.info("run workflow result: %s" % str(result))
    return 0


def run_work(context, original_args, current_job_kwargs):
    """
    Decode the arguments, build a Work from them and run it.

    :param context: context object; initialize() and setup_source_files() are
        called on it before the Work is created.
    :param original_args: serialized call description, decoded by get_context_args.
    :param current_job_kwargs: serialized per-job keyword arguments, decoded by
        get_context_args.

    :returns: always 0; the value returned by Work.run() is only logged.
    """
    decoded = get_context_args(context, original_args, current_job_kwargs)
    context, func_name, pre_kwargs, args, kwargs, multi_jobs_kwargs_list, current_job_kwargs = decoded

    # Values wrapped in str() may be sequences, which must not be expanded by "%".
    for template, value in (
        ("context: %s", context),
        ("func_name: %s", func_name),
        ("pre_kwargs: %s", pre_kwargs),
        ("args: %s", str(args)),
        ("kwargs: %s", kwargs),
        ("multi_jobs_kwargs_list: %s", str(multi_jobs_kwargs_list)),
        ("current_job_kwargs: %s", str(current_job_kwargs)),
    ):
        logging.info(template % value)

    context.initialize()
    context.setup_source_files()

    work = Work(
        func=func_name,
        pre_kwargs=pre_kwargs,
        args=args,
        kwargs=kwargs,
        multi_jobs_kwargs_list=multi_jobs_kwargs_list,
        current_job_kwargs=current_job_kwargs,
        context=context,
    )
    logging.info("work: %s" % work)

    result = work.run()
    logging.info("run work result: %s" % str(result))
    return 0


def run_iworkflow(args):
    """
    Decode the command line payloads and run them as a workflow or a work.

    :param args: parsed command line namespace providing the attributes
        ``context``, ``original_args``, ``current_job_kwargs`` and ``type``.

    :returns: the exit code returned by run_workflow (when ``args.type`` is
        'workflow') or by run_work (any other type).
    :raises ValueError: if no context was supplied or it decodes to None.
    """
    context = json_loads(decode_base64(args.context)) if args.context else None
    if context is None:
        # Fail with an explicit message instead of an AttributeError on None
        # when the broker password is masked below.
        raise ValueError("A context is required to run a workflow or a work (see --context)")

    original_args = decode_base64(args.original_args) if args.original_args else None

    if args.current_job_kwargs:
        current_job_kwargs = decode_base64(args.current_job_kwargs)
        logging.info(current_job_kwargs)
    else:
        current_job_kwargs = None

    # Both types share the same steps; only the label and the runner differ.
    if args.type == 'workflow':
        label, runner = "run workflow", run_workflow
    else:
        label, runner = "run work", run_work
    logging.info(label)

    # Mask the broker password while the context is dumped to the log and
    # restore it even if the dump fails.
    password = context.broker_password
    context.broker_password = '***'
    try:
        logging.info("context: %s" % json_dumps(context))
    finally:
        context.broker_password = password

    logging.info("original_args: %s" % original_args)
    logging.info("current_job_kwargs: %s" % current_job_kwargs)
    exit_code = runner(context, original_args, current_job_kwargs)
    logging.info("exit code: %s" % exit_code)
    return exit_code


def custom_action():
    """
    Build and return an argparse action class that echoes the option value.
    """
    class CustomAction(argparse.Action):
        """Action that prints the parsed value and stores nothing."""

        def __init__(self, option_strings, dest, default=False, required=False, help=None):
            init_kwargs = dict(option_strings=option_strings, dest=dest, const=True,
                               default=default, required=required, help=help)
            super().__init__(**init_kwargs)

        def __call__(self, parser, namespace, values=None, option_string=None):
            # The value is only written to stdout; the namespace is left untouched.
            print(values)

    return CustomAction


def get_parser():
    """
    Build the command line parser of this script.

    :returns: an argparse.ArgumentParser with the --version, --verbose/-v,
        --type, --context, --original_args and --current_job_kwargs options.
    """
    parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), add_help=True)

    # Each entry is (option flags, keyword arguments for add_argument).
    option_specs = (
        (('--version',),
         dict(action='version', version='%(prog)s ' + release_version)),
        (('--verbose', '-v'),
         dict(default=False, action='store_true', help="Print more verbose output.")),
        (('--type',),
         dict(dest='type', action='store', choices=['workflow', 'work'], default='workflow',
              help='The type in [workflow, work]. Default is workflow.')),
        (('--context',),
         dict(dest='context', help="The context.")),
        (('--original_args',),
         dict(dest='original_args', help="The original arguments.")),
        (('--current_job_kwargs',),
         dict(dest='current_job_kwargs', nargs='?', const=None, help="The current job arguments.")),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser


if __name__ == '__main__':
    # Script entry point: parse the command line, run the requested type and
    # exit with the code returned by run_iworkflow (or -1 on any exception).
    cli_arguments = sys.argv[1:]

    cli_parser = get_parser()
    argcomplete.autocomplete(cli_parser)
    args = cli_parser.parse_args(cli_arguments)

    try:
        if args.verbose:
            logging.getLogger().setLevel(logging.DEBUG)

        started_at = time.time()
        exit_code = run_iworkflow(args)
        elapsed = time.time() - started_at
        if args.verbose:
            print("Completed in %-0.4f sec." % elapsed)
        # SystemExit is not an Exception subclass, so the handler below does not catch it.
        sys.exit(exit_code)
    except Exception as error:
        logging.error("Strange error: {0}".format(error))
        logging.error(traceback.format_exc())
        sys.exit(-1)
