#!/usr/bin/env python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
#
# Authors:
# - Wen Guan, <wen.guan@cern.ch>, 2024 - 2025


"""
Run workflow.
"""

from __future__ import print_function

import argparse
import base64
import logging
import json
import os
import sys
import time
import traceback


# Log timestamps are rendered with time.gmtime (UTC) instead of local time.
logging.Formatter.converter = time.gmtime
# All records from DEBUG upwards go to stderr as tab separated columns:
# time, thread name, logger name, level and message.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.DEBUG,
    format='\t'.join((
        '%(asctime)s',
        '%(threadName)s',
        '%(name)s',
        '%(levelname)s',
        '%(message)s',
    )),
)


def get_parser(program):
    """
    Build the command line parser of this wrapper.

    :param program: path of the running program; only its base name is
                    used as the parser's ``prog``.
    :returns: an ``argparse.ArgumentParser`` that knows ``--pre_setup``,
              ``--setup`` and collects everything else in ``run_args``.
    """
    parser = argparse.ArgumentParser(prog=os.path.basename(program), add_help=True)

    # common items: (option/dest name, help text)
    # The output_map, inputs and input_map options are currently disabled.
    common_options = (
        ('pre_setup', "The pre_setup."),
        ('setup', "The setup."),
    )
    for option_name, help_text in common_options:
        parser.add_argument('--' + option_name, dest=option_name, help=help_text)

    # Whatever follows the known options is kept verbatim as the command to run.
    parser.add_argument('run_args', nargs=argparse.REMAINDER, help="All other arguments")
    return parser


def get_workflow_setup_script():
    """
    Return the bash prologue placed at the top of the generated run script.

    The text prints the working directory and the python interpreters found,
    prepends the working directory (and its tmp_bin, bin and lib_py
    sub-directories) to PATH / PYTHONPATH, links python3 to ./python when no
    ``python`` command exists, and exports X509_USER_PROXY when an
    ``x509_proxy`` file is present in the working directory.

    :returns: the bash script text, ending with a newline.
    """
    diagnostics = """#!/bin/bash

echo "current dir: " $PWD

which python
which python3

# cd ${current_dir}
"""

    environment = """
# if it's in a container, this part is needed again to setup the environment.
current_dir=$PWD
export PATH=${current_dir}:${current_dir}/tmp_bin:${current_dir}/bin:$PATH
export PYTHONPATH=${current_dir}:${current_dir}/lib_py:$PYTHONPATH
"""

    python_fallback = """
if ! command -v python &> /dev/null
then
    echo "no python, link python3 to python"
    # alias python=python3
    ln -fs $(which python3) ./python
fi
"""

    proxy = """
if [ -f ${current_dir}/x509_proxy ]; then
    export X509_USER_PROXY=${current_dir}/x509_proxy
fi
"""
    return diagnostics + environment + python_fallback + proxy


# This part may have to run on SL7 with python2
def decode_base64(sb, remove_quotes=False):
    """
    Decode a base64 encoded value into a UTF-8 string (best effort).

    :param sb: base64 payload as ``str`` or ``bytes``. A value of any other
               type is returned unchanged.
    :param remove_quotes: when True, strip one pair of matching single or
                          double quotes wrapping the decoded text. Text that
                          is not wrapped in quotes is returned intact.
    :returns: the decoded string, or ``sb`` itself when it is not a string
              or cannot be decoded (the failure is logged, not raised).
    """
    try:
        if isinstance(sb, bytes):
            # Covers python 2 as well, where str is already bytes.
            sb_bytes = sb
        elif isinstance(sb, str):
            sb_bytes = bytes(sb, 'ascii')
        else:
            return sb
        decode_str = base64.b64decode(sb_bytes).decode("utf-8")
        # Remove the quotes after decoding, but only when the text really is
        # wrapped in a matching pair; slicing blindly would eat the first and
        # last characters of an unquoted value.
        if (remove_quotes and len(decode_str) >= 2
                and decode_str[0] == decode_str[-1] and decode_str[0] in "'\""):
            return decode_str[1:-1]
        return decode_str
    except Exception as ex:
        # Deliberate best effort: report the problem and hand back the input.
        logging.error("decode_base64 %s: %s", sb, ex)
        return sb


def create_run_workflow_cmd(run_args):
    """
    Write the executable ``./run_workflow.sh`` launcher and return its path.

    The file holds the environment setup script followed by one command line
    in which every item of ``run_args`` is wrapped in double quotes and the
    items are separated by single spaces.

    NOTE(review): double quotes embedded in an argument are not escaped, and
    the shell still expands ``$`` and backticks inside the quotes - confirm
    that callers rely on this before changing the quoting.

    :param run_args: iterable of the command and its arguments.
    :returns: the relative path of the generated script.
    """
    script_path = "./run_workflow.sh"

    prologue = get_workflow_setup_script()
    quoted_args = ['"%s"' % str(arg) for arg in run_args]
    content = "{0}\n{1}".format(prologue, " ".join(quoted_args))

    logging.debug("script: ")
    logging.debug(content)

    with open(script_path, 'w') as script_file:
        script_file.write(content)
    # rwxr-xr-x, so the script can be executed directly.
    os.chmod(script_path, 0o755)
    return script_path


def process_args(args):
    """
    Turn the parsed arguments into the command line to execute.

    ``pre_setup`` and ``setup`` are base64 encoded JSON values; when present
    and non-empty after decoding they are placed, in that order, in front of
    the path of the run script generated from ``run_args``.

    :param args: namespace with the ``pre_setup``, ``setup`` and ``run_args``
                 attributes.
    :returns: the assembled command string.
    """
    for label, value in (("pre_setup:", args.pre_setup),
                         ("setup: ", args.setup),
                         ("run_args:", args.run_args)):
        logging.debug(label)
        logging.debug(value)

    command = ""
    # (encoded value, separator put in front of the decoded text)
    for encoded, separator in ((args.pre_setup, ""), (args.setup, " ")):
        if not encoded:
            continue
        decoded = json.loads(decode_base64(encoded, remove_quotes=False))
        if decoded:
            command += separator + decoded

    # The script is only written once both setup values decoded fine.
    launcher = create_run_workflow_cmd(args.run_args)
    return command + " " + launcher


if __name__ == '__main__':
    # Entry point: parse the command line, print the generated command on
    # stdout and report progress and failures through logging.
    cli_args = sys.argv[1:]
    parser = get_parser(sys.argv[0])

    logging.debug("arguments: ")
    logging.debug(cli_args)
    parsed_args = parser.parse_args(cli_args)

    try:
        started_at = time.time()
        print(process_args(parsed_args))
        elapsed = time.time() - started_at
        logging.info("Completed processing args in %-0.4f sec." % elapsed)
        # SystemExit is not an Exception subclass, so it is not caught below.
        sys.exit(0)
    except Exception as error:
        logging.error("Strange error: {0}".format(error))
        logging.error(traceback.format_exc())
        sys.exit(-1)
