#!/usr/bin/env python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2024 - 2025


"""
Run workflow.
"""

from __future__ import print_function

import argparse
# import argcomplete
import ast
import base64
import json
import logging
import pickle
import re
import os
import sys
import time
import traceback
import zlib

from collections import defaultdict

from idds.common.utils import json_dumps, json_loads, setup_logging, decode_base64
# from idds.common.utils import merge_dict
from idds.iworkflow.version import release_version
from idds.iworkflow.workflow import Workflow
from idds.iworkflow.work import Work


setup_logging(__name__, stream=sys.stdout)


def get_context_args(context, original_args, current_job_kwargs):
    """
    Decode the serialized call description and the per-job keyword arguments.

    :param context: the context object; it is returned untouched.
    :param original_args: JSON text that decodes to the 4-item sequence
                          ``(func_name, pre_kwargs, args, kwargs)``. Each of the
                          last three, when non-empty, is a base64-encoded,
                          zlib-compressed pickle. May be empty or None.
    :param current_job_kwargs: JSON text wrapping a base64/zlib/pickle payload,
                               the unexpanded placeholder ``${IN/L}``, or an
                               empty value.

    :returns: the tuple ``(context, func_name, pre_kwargs, args, kwargs,
              multi_jobs_kwargs_list, current_job_kwargs)``.
    """
    # NOTE(review): pickle.loads can execute arbitrary code when fed crafted
    # data; the payloads decoded here must come from a trusted submitter.
    def unpack(blob):
        return pickle.loads(zlib.decompress(base64.b64decode(blob)))

    func_name = pre_kwargs = args = kwargs = None
    # Nothing in original_args populates this value, so it stays None below.
    multi_jobs_kwargs_list = None

    if original_args:
        func_name, pre_kwargs, args, kwargs = json_loads(original_args)

    if args:
        args = unpack(args)
    if pre_kwargs:
        pre_kwargs = unpack(pre_kwargs)
    if kwargs:
        kwargs = unpack(kwargs)
    if multi_jobs_kwargs_list:
        multi_jobs_kwargs_list = [unpack(item) for item in multi_jobs_kwargs_list]

    if current_job_kwargs:
        if current_job_kwargs == "${IN/L}":
            # The placeholder was never substituted: leave the value alone.
            logging.info("current_job_kwargs == original ${IN/L}, is not set")
        else:
            try:
                current_job_kwargs = json_loads(current_job_kwargs)
                if current_job_kwargs:
                    current_job_kwargs = unpack(current_job_kwargs)
            except Exception as ex:
                # Best effort: whatever was decoded so far is returned.
                logging.error("Failed to update kwargs: %s" % ex)

    return context, func_name, pre_kwargs, args, kwargs, multi_jobs_kwargs_list, current_job_kwargs


def map_output_files(output_map):
    """
    Create symbolic links that expose produced files under their mapped names.

    :param output_map: mapping ``{original_file: mapped_file}``. A falsy value
                       means there is nothing to do. A value without an
                       ``items()`` method (for example a string that could not
                       be parsed into a dict) is reported and ignored instead of
                       raising.

    :returns: None. Missing source files and errors raised by ``os.symlink``
              are logged and the remaining entries are still processed.
    """
    if not output_map:
        logging.info("No output mapping provided.")
        return

    if not callable(getattr(output_map, "items", None)):
        # Linking is a post-processing step; a malformed map must not turn a
        # finished job into a crash.
        logging.error(f"Invalid output mapping (expected a dict), ignoring it: {output_map!r}")
        return

    for org_file, mapped_file in output_map.items():
        if not os.path.exists(org_file):
            logging.warning(f"cannot link {org_file} to {mapped_file}, because it doesn't exist")
            continue
        try:
            os.symlink(org_file, mapped_file)
        except Exception as e:
            logging.error(f"Failed to link {org_file} to {mapped_file}: {e}")
        else:
            logging.info(f"link {org_file} to {mapped_file}")


def run_workflow(name, key, context, original_args, current_job_kwargs, output_map, inputs, input_map, inputs_group=None):
    """
    Decode the arguments, build a Workflow and run it.

    :param name: name given to the Workflow.
    :param key: job key (only logged here).
    :param context: context object; ``initialize()`` and ``setup_source_files()``
                    are called on it before the workflow is built.
    :param original_args: serialized function description, see get_context_args.
    :param current_job_kwargs: serialized per-job kwargs, see get_context_args.
    :param output_map: output mapping (only logged here).
    :param inputs: inputs (only logged here).
    :param input_map: input map (only logged here).
    :param inputs_group: grouped inputs (only logged here). Optional so that
                         callers passing it positionally no longer fail with a
                         TypeError while older 8-argument calls keep working.

    :returns: 0 when ``workflow.run()`` returns a truthy value, otherwise -1.
    """
    context, func_name, pre_kwargs, args, kwargs, multi_jobs_kwargs_list, current_job_kwargs = get_context_args(context, original_args, current_job_kwargs)
    logging.info(f"name: {name}, key: {key}")
    # Logging arguments are passed separately so tuple values are rendered
    # as-is instead of being expanded by the % operator.
    logging.info("context: %s", context)
    logging.info("func_name: %s", func_name)
    logging.info("pre_kwargs: %s", pre_kwargs)
    logging.info("args: %s", args)
    logging.info("kwargs: %s", kwargs)
    logging.info("multi_jobs_kwargs_list: %s", multi_jobs_kwargs_list)
    logging.info("current_job_kwargs: %s", current_job_kwargs)
    logging.info(f"output_map: {output_map}")
    logging.info(f"inputs: {inputs}")
    logging.info(f"input_map: {input_map}")
    logging.info(f"inputs_group: {inputs_group}")

    context.initialize()
    context.setup_source_files()

    workflow = Workflow(func=func_name, pre_kwargs=pre_kwargs, args=args, kwargs=kwargs, multi_jobs_kwargs_list=multi_jobs_kwargs_list, current_job_kwargs=current_job_kwargs, context=context, name=name)
    workflow.load()
    logging.info("workflow: %s", workflow)
    with workflow:
        ret = workflow.run()
    logging.info("run workflow result: %s", ret)
    return 0 if ret else -1


def run_work(name, key, context, original_args, current_job_kwargs, output_map, inputs, input_map, inputs_group):
    """
    Decode the arguments, build a Work, run it and link its output files.

    :param name: name given to the Work.
    :param key: job key, passed to the Work as ``job_key``.
    :param context: context object; ``initialize()`` and ``setup_source_files()``
                    are called on it before the work is built.
    :param original_args: serialized function description, see get_context_args.
    :param current_job_kwargs: serialized per-job kwargs, see get_context_args.
    :param output_map: mapping handed to map_output_files once the work has run.
    :param inputs: inputs forwarded to the Work.
    :param input_map: input map forwarded to the Work.
    :param inputs_group: grouped inputs forwarded to the Work.

    :returns: 0 when ``work.run()`` returns a truthy value, otherwise -1.
    """
    decoded = get_context_args(context, original_args, current_job_kwargs)
    context, func_name, pre_kwargs, args, kwargs, multi_jobs_kwargs_list, current_job_kwargs = decoded

    logging.info(f"name: {name}, key: {key}")
    logging.info("context: %s" % context)
    logging.info("func_name: %s" % func_name)
    logging.info("pre_kwargs: %s" % pre_kwargs)
    logging.info("args: %s" % str(args))
    logging.info("kwargs: %s" % kwargs)
    logging.info("multi_jobs_kwargs_list: %s" % str(multi_jobs_kwargs_list))
    logging.info("current_job_kwargs: %s" % str(current_job_kwargs))
    logging.info(f"output_map: {output_map}")
    logging.info(f"inputs: {inputs}")

    context.initialize()
    context.setup_source_files()

    work_options = dict(
        func=func_name,
        pre_kwargs=pre_kwargs,
        args=args,
        kwargs=kwargs,
        multi_jobs_kwargs_list=multi_jobs_kwargs_list,
        job_key=key,
        current_job_kwargs=current_job_kwargs,
        context=context,
        name=name,
        inputs=inputs,
        input_map=input_map,
        inputs_group=inputs_group,
    )
    work = Work(**work_options)
    work.load()
    logging.info("work: %s" % work)

    succeeded = work.run()
    logging.info("run work result: %s" % str(succeeded))

    # Output files are linked whatever the outcome of the run was.
    map_output_files(output_map)
    return 0 if succeeded else -1


def _parse_literal_or_json(label, value):
    """Parse text as a Python literal, then as JSON; return it unchanged when neither works."""
    if not isinstance(value, (str, bytes, bytearray)):
        # Already a structured object: nothing to parse.
        return value
    try:
        value = ast.literal_eval(value)
    except Exception as ex:
        logging.warning(f"failed to load {label} with ast: {ex}")
    # JSON is only attempted while the value is still text; running it on an
    # already parsed list/dict would just log a misleading failure.
    if isinstance(value, (str, bytes, bytearray)):
        try:
            value = json.loads(value)
        except Exception as ex:
            logging.warning(f"failed to load {label} with json: {ex}")
    return value


def run_iworkflow(args, inputs_group):
    """
    Decode the command line options and run a workflow or a work.

    :param args: parsed options (argparse namespace) providing ``type``, ``name``,
                 ``key``, ``context``, ``original_args``, ``current_job_kwargs``,
                 ``args_file``, ``output_map``, ``inputs`` and ``input_map``.
                 ``output_map``, ``inputs`` and ``input_map`` are replaced in
                 place by their parsed values.
    :param inputs_group: grouped inputs, forwarded to run_work.

    :returns: the exit code returned by run_workflow or run_work.
    :raises ValueError: when no context is available.
    """
    # The arguments file is applied first so that the values it carries are the
    # ones decoded below (they were previously stored after decoding and hence
    # never used).
    # NOTE(review): assumes the file holds 'context', 'original_args' and
    # 'current_job_kwargs' in the same base64 form as the command line - confirm.
    if args.args_file:
        with open(args.args_file, 'r') as file:
            data = file.read()
        args_content = json_loads(decode_base64(data))
        for field in ('type', 'context', 'original_args', 'current_job_kwargs'):
            if field in args_content:
                setattr(args, field, args_content[field])

    context = json_loads(decode_base64(args.context)) if args.context else None
    original_args = decode_base64(args.original_args) if args.original_args else None
    if args.current_job_kwargs:
        current_job_kwargs = decode_base64(args.current_job_kwargs)
        logging.info(current_job_kwargs)
    else:
        current_job_kwargs = None

    if args.output_map:
        args.output_map = _parse_literal_or_json("args.output_map", args.output_map)
    if args.inputs:
        args.inputs = _parse_literal_or_json("args.inputs", args.inputs)
    if args.input_map:
        args.input_map = _parse_literal_or_json("args.input_map", args.input_map)

    if context is None:
        # Fail with a clear message instead of an AttributeError on None below.
        raise ValueError("No context provided: pass --context or an --args_file containing 'context'")

    is_workflow = args.type == 'workflow'
    logging.info("run workflow" if is_workflow else "run work")

    # Hide the broker password while the context is dumped to the log and make
    # sure it is restored even if the dump fails.
    password = context.broker_password
    context.broker_password = '***'
    try:
        logging.info(f"name: {args.name}, key: {args.key}")
        logging.info("context: %s" % json_dumps(context))
    finally:
        context.broker_password = password

    logging.info("original_args: %s" % original_args)
    logging.info(f"current_job_kwargs: {type(current_job_kwargs)},{current_job_kwargs}")
    logging.info(f"inputs: {args.inputs}")
    logging.info(f"input_map: {args.input_map}")
    logging.info(f"inputs_group: {inputs_group}")

    if is_workflow:
        # run_workflow takes eight positional arguments; passing inputs_group
        # as a ninth one raised a TypeError.
        exit_code = run_workflow(args.name, args.key, context, original_args, current_job_kwargs,
                                 args.output_map, args.inputs, args.input_map)
    else:
        exit_code = run_work(args.name, args.key, context, original_args, current_job_kwargs,
                             args.output_map, args.inputs, args.input_map, inputs_group)
    logging.info("exit code: %s" % exit_code)
    return exit_code


def custom_action():
    """
    Build and return an argparse action class that prints the value it receives.
    """
    class CustomAction(argparse.Action):
        """Action created with ``const=True`` that echoes its value."""

        def __init__(self, option_strings, dest, default=False, required=False, help=None):
            settings = {
                'option_strings': option_strings,
                'dest': dest,
                'const': True,
                'default': default,
                'required': required,
                'help': help,
            }
            super().__init__(**settings)

        def __call__(self, parser, namespace, values=None, option_string=None):
            # The value is only printed; it is not stored on the namespace.
            print(values)

    return CustomAction


def parse_item(name, value):
    """
    Parse a textual value as a Python literal, falling back to JSON.

    :param name: label used in the warning messages.
    :param value: the text to parse. A value that is not text is returned
                  unchanged.

    :returns: the parsed object, or the original value when neither parser
              accepts it (a warning is logged for each parser that failed).
    """
    ret = value
    if not isinstance(ret, (str, bytes, bytearray)):
        # Already a structured object: running the parsers on it could only
        # produce warnings.
        return ret

    try:
        ret = ast.literal_eval(ret)
    except Exception as ex:
        logging.warning(f"failed to load {name}: {ret} with ast: {ex}")

    # JSON is attempted only while the value is still text, i.e. the literal
    # parser failed or produced a string. Calling json.loads on a parsed
    # list/dict always fails and used to log a bogus warning.
    if isinstance(ret, (str, bytes, bytearray)):
        try:
            ret = json.loads(ret)
        except Exception as ex:
            logging.warning(f"failed to load {name}: {ret} with json: {ex}")
    return ret


def parse_unknown_grouped_args(args):
    """
    Collect the numbered ``--inputsN`` / ``--input_mapN`` options.

    Both ``--inputsN VALUE`` and ``--inputsN=VALUE`` are accepted. In the first
    form the next token is taken as the value unless it starts with ``--``.
    Tokens that merely start with a recognized option name (for example
    ``--inputs1foo``) are ignored.

    :param args: list of command line tokens not consumed by the main parser.
                 None is treated as an empty list.

    :returns: dict mapping each group's ``input_map`` value to its parsed
              ``inputs`` (an empty list when the group has no inputs), filled
              in ascending group number. Groups without an ``input_map`` are
              dropped.
    """
    args = args or []
    groups = defaultdict(lambda: {'inputs': [], 'input_map': None})
    # DOTALL lets an inline value contain newlines.
    pattern = re.compile(r"--(inputs|input_map)(\d+)(?:=(.*))?", re.DOTALL)

    pos, total = 0, len(args)
    while pos < total:
        match = pattern.fullmatch(args[pos])
        pos += 1
        if not match:
            continue
        key, group_id, inline_value = match.group(1), int(match.group(2)), match.group(3)
        if inline_value is not None:
            groups[group_id][key] = inline_value
        elif pos < total and not args[pos].startswith('--'):
            groups[group_id][key] = args[pos]
            pos += 1

    rets = {}
    for group_id in sorted(groups):
        group = groups[group_id]
        if group['input_map'] is None:
            continue
        inputs = group['inputs']
        # Only text needs parsing; the default empty list is used as it is.
        if isinstance(inputs, str):
            inputs = parse_item('inputs', inputs)
        rets[group['input_map']] = inputs
    return rets


def get_parser():
    """
    Build the command line parser of the runner.

    :returns: the configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(prog=os.path.basename(sys.argv[0]), add_help=True)

    # (flags, options) pairs, registered in this order.
    # NOTE(review): --output and --mapped_output store into the same 'output'
    # destination, so the last one given wins - confirm this is intended.
    option_specs = (
        (('--version',), dict(action='version', version='%(prog)s ' + release_version)),
        (('--verbose', '-v'), dict(default=False, action='store_true', help="Print more verbose output.")),
        (('--type',), dict(dest='type', action='store', choices=['workflow', 'work'], default='workflow',
                           help='The type in [workflow, work]. Default is workflow.')),
        (('--name',), dict(dest='name', help="The name.")),
        (('--key',), dict(dest='key', help="The key.")),
        (('--context',), dict(dest='context', help="The context.")),
        (('--original_args',), dict(dest='original_args', help="The original arguments.")),
        (('--current_job_kwargs',), dict(dest='current_job_kwargs', nargs='?', const=None,
                                         help="The current job arguments.")),
        (('--args_file',), dict(dest='args_file', help="The file with arguments")),
        (('--output',), dict(dest='output', help="The output file name.")),
        (('--mapped_output',), dict(dest='output', help="The output file name to be mapped.")),
        (('--output_map',), dict(dest='output_map', help="The output map.")),
        (('--inputs',), dict(dest='inputs', help="The input list.")),
        (('--input_map',), dict(dest='input_map', help="The input map.")),
        (('--others',), dict(dest='others', nargs='*', help="Additional custom arguments (e.g., key=value).")),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)
    return parser


if __name__ == '__main__':
    # Options the parser does not know are kept aside: they may carry the
    # numbered --inputsN / --input_mapN groups.
    parser = get_parser()
    # argcomplete.autocomplete(parser)
    args, unknown = parser.parse_known_args(sys.argv[1:])

    try:
        if args.verbose:
            logging.getLogger().setLevel(logging.DEBUG)
        started_at = time.time()

        inputs_group = parse_unknown_grouped_args(unknown)
        exit_code = run_iworkflow(args, inputs_group)

        elapsed = time.time() - started_at
        if args.verbose:
            print("Completed in %-0.4f sec." % elapsed)
        # SystemExit is not an Exception subclass, so it is not caught below.
        sys.exit(exit_code)
    except Exception as error:
        logging.error("Strange error: {0}".format(error))
        logging.error(traceback.format_exc())
        sys.exit(-1)
