#!python
# -*- coding: utf-8 -*-
# Copyright (C) 2018-2022 fjord-technologies
# SPDX-License-Identifier: GPL-3.0-or-later
"""auton"""

__version__ = '0.2.22'

import argparse
import os
import shlex
import sys
import time
import uuid

import logging
from logging.handlers import WatchedFileHandler

from dotenv.main import dotenv_values
from sonicprobe import helpers
from sonicprobe.libs import urisup

import requests
from requests import exceptions

import six

# Logger name; also used as the product token of the User-Agent header
SYSLOG_NAME     = "auton"
LOG             = logging.getLogger(SYSLOG_NAME)

# Fallback log file path, and default value of --delay (passed to time.sleep)
DEFAULT_LOGFILE = "/var/log/auton/auton.log"
DEFAULT_DELAY   = 0.5

# Default of --logfile: AUTON_LOGFILE from the environment when set and
# non-empty, DEFAULT_LOGFILE otherwise
AUTON_LOGFILE   = os.environ.get('AUTON_LOGFILE') or DEFAULT_LOGFILE


def argv_parse_check():
    """
    Parse (and check a little) command line parameters

    Several defaults are taken from the environment: AUTON_UID,
    AUTON_ENDPOINT, AUTON_AUTH_USER, AUTON_AUTH_PASSWD, AUTON_NO_RETURN_CODE,
    and AUTON_URI when no --uri option is given.

    Side effect: sys.argv is replaced by the result of
    helpers.escape_parse_args() before parsing.

    Unknown command line arguments are ignored (parse_known_args).

    @return: the argparse namespace; `loglevel` is converted to the numeric
             level of the logging module and `uri` is always a list.
    """
    parser        = argparse.ArgumentParser()

    parser.add_argument("-a",
                        action    = 'append',
                        dest      = 'args',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Passed arguments")
    parser.add_argument("-A",
                        action    = 'append',
                        dest      = 'argfiles',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Passed argument files")
    parser.add_argument("--multi-args",
                        action    = 'append',
                        dest      = 'margs',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Passed multiple arguments (arguments splitted)")
    parser.add_argument("--multi-argsfiles",
                        action    = 'append',
                        dest      = 'margsfiles',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Passed multiple arguments files (arguments splitted)")
    parser.add_argument("--uri",
                        action    = 'append',
                        dest      = 'uri',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Auton daemon URI addresses")
    parser.add_argument("--uid",
                        dest      = 'uid',
                        type      = six.ensure_text,
                        # a random uuid4 is generated when no uid is supplied
                        default   = os.environ.get('AUTON_UID') or ("%s" % uuid.uuid4()),
                        help      = "Auton uid")
    parser.add_argument("--endpoint",
                        dest      = 'endpoint',
                        type      = six.ensure_text,
                        default   = os.environ.get('AUTON_ENDPOINT'),
                        help      = "Auton endpoint")
    parser.add_argument("--auth-user",
                        dest      = 'auth_user',
                        type      = six.ensure_text,
                        default   = os.environ.get('AUTON_AUTH_USER'),
                        help      = "Auton auth user")
    parser.add_argument("--auth-passwd",
                        dest      = 'auth_passwd',
                        type      = six.ensure_text,
                        default   = os.environ.get('AUTON_AUTH_PASSWD'),
                        help      = "Auton auth password")
    parser.add_argument("--delay",
                        dest      = 'delay',
                        type      = float,
                        default   = DEFAULT_DELAY,
                        help      = "Delay between requests instead of %(default)s")
    parser.add_argument("--mode",
                        dest      = 'mode',
                        default   = 'autorun',
                        choices   = ('autorun', 'run', 'status'),
                        help      = "Auton mode: autorun, run, status, instead of %(default)s")
    parser.add_argument("-e",
                        action    = 'append',
                        dest      = 'envvars',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Passed environment variables")
    parser.add_argument("--envfile",
                        action    = 'append',
                        dest      = 'envfiles',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Passed envfile parameters")
    parser.add_argument("--imp-env",
                        action    = 'append',
                        dest      = 'imp_envvars',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Import existing environment variables")
    parser.add_argument("--load-envfile",
                        action    = 'append',
                        dest      = 'load_envfiles',
                        type      = six.ensure_text,
                        default   = [],
                        help      = "Load environment variables from file")
    parser.add_argument("-l",
                        dest      = 'loglevel',
                        default   = 'info',   # NOTE: converted to a numeric level after parsing
                        choices   = ('critical', 'error', 'warning', 'info', 'debug'),
                        help      = ("Emit traces with LOGLEVEL details, must be one of:\t"
                                     "critical, error, warning, info, debug"))
    parser.add_argument("--logfile",
                        dest      = 'logfile',
                        type      = six.ensure_text,
                        default   = AUTON_LOGFILE,
                        help      = "Use log file <logfile> instead of %(default)s")
    parser.add_argument("--no-return-code",
                        action    = 'store_true',
                        dest      = 'no_return_code',
                        default   = helpers.boolize(os.environ.get('AUTON_NO_RETURN_CODE', False)),
                        help      = "Do not exit with return code if present")

    # NOTE(review): presumably protects the values of these options from
    # being interpreted as options themselves -- confirm against
    # sonicprobe.helpers.escape_parse_args
    sys.argv      = helpers.escape_parse_args(('-a',
                                               '-A',
                                               '--multi-args',
                                               '--multi-argsfiles'),
                                              sys.argv)

    # Unknown arguments are deliberately ignored
    options, _    = parser.parse_known_args()
    options.loglevel = getattr(logging, options.loglevel.upper(), logging.INFO)

    if not options.uri and os.environ.get('AUTON_URI'):
        # Comma separated list; blank entries (e.g. a trailing comma) are
        # dropped so that an empty URI never reaches the HTTP layer
        options.uri = [x.strip()
                       for x in os.environ.get('AUTON_URI').split(',')
                       if x.strip()]

    return options


class AutonClient(object): # pylint: disable=useless-object-inheritance
    """
    HTTP client of the auton daemon (autond).

    The client posts a job on `/run/<endpoint>/<uid>` and polls
    `/status/<endpoint>/<uid>` on the daemon that accepted the job.
    """

    def __init__(self, options):
        """
        Validate the options and pre-compute arguments, argument files and
        environment variables sent with the run request.

        @param options: namespace as returned by argv_parse_check()
        @raise ValueError: when uri, uid or endpoint is missing
        """
        self.options  = options
        self.envvars  = {}
        self.argfiles = []
        self.uri      = None   # URI of the daemon that answered, once known

        if not self.options.uri:
            raise ValueError("missing variable AUTON_URI")

        if not self.options.uid:
            raise ValueError("missing variable AUTON_UID")

        if not self.options.endpoint:
            raise ValueError("missing variable AUTON_ENDPOINT")

        # Order matters: the multi-* options extend options.args and
        # options.argfiles, and later environment sources override earlier
        # ones (imported < env files < -e options).
        self._parse_multi_args()
        self._parse_multi_argsfiles()
        self._import_envvars()
        self._load_envfiles()
        self._parse_envvars()
        self._parse_argfiles()
        self._auth = None
        if self.options.auth_user:
            self._auth = (self.options.auth_user,
                          self.options.auth_passwd or '')

    @staticmethod
    def _check_results(res = None):
        """
        Validate a decoded daemon response.

        @param res: decoded response (mapping) or None
        @return: `res` when it carries a truthy 'code' and no 'message'
        @raise LookupError: when `res` is empty, carries a 'message', or has
                            no truthy 'code'
        """
        if not res:
            raise LookupError("unknown error")

        if res.get('message'):
            raise LookupError("invalid request, message: %s" % res['message'])

        if res.get('code'):
            return res

        raise LookupError("errors occurred: %r" % res)

    @staticmethod
    def _show_results(data):
        """
        Write the 'stream' items to stdout and, once the status is
        'complete', the 'errors' items to stderr.
        """
        if data.get('stream'):
            for x in data['stream']:
                sys.stdout.write(x)

        if data['status'] == 'complete' \
           and data.get('errors'):
            for x in data['errors']:
                sys.stderr.write(x)

    def _import_envvars(self):
        """
        Copy the variables named by --imp-env from the current environment.
        """
        for envvar in self.options.imp_envvars:
            if envvar in os.environ:
                self.envvars[envvar] = os.environ[envvar]
            else:
                LOG.debug("unable to find environment variable: %r", envvar)

    def _load_envfiles(self):
        """
        Merge the variables of every --load-envfile file (dotenv format).
        """
        for envfile in self.options.load_envfiles:
            self.envvars.update(dotenv_values(envfile))

    def _parse_envvars(self):
        """
        Merge the `NAME=VALUE` items given with -e; `NAME` alone means an
        empty value, an empty name is rejected with a warning.
        """
        for envvar in self.options.envvars:
            env = envvar.split('=', 1)
            if not env[0]:
                LOG.warning("invalid environment variable: %r", env)
            elif len(env) == 1:
                self.envvars[env[0]] = ''
            else:
                self.envvars[env[0]] = env[1]

    def _parse_multi_args(self):
        """
        Shell-split every --multi-args value and append the resulting words
        to options.args.
        """
        for margs in self.options.margs:
            args = shlex.split(margs)
            if args:
                self.options.args.extend(args)

    def _parse_multi_argsfiles(self):
        """
        Shell-split every --multi-argsfiles value and append the resulting
        words to options.argfiles.
        """
        for margsfiles in self.options.margsfiles:
            argsfiles = shlex.split(margsfiles)
            if argsfiles:
                self.options.argfiles.extend(argsfiles)

    def _parse_argfiles(self):
        """
        Build the list of argument files from the `ARG=PATH` items of
        options.argfiles. A PATH of '-' reads the content from stdin.
        Invalid items and missing files are skipped with a warning.
        """
        for argfile in self.options.argfiles:
            arg = argfile.split('=', 1)
            if not arg[0] or arg[0] == '@' or len(arg) == 1:
                LOG.warning("invalid argument file name: %r", arg)
            elif arg[1] == '-':
                # NOTE(review): assumes helpers.read_large_file() returns the
                # whole stdin content as bytes -- confirm against sonicprobe
                data = six.BytesIO(helpers.read_large_file(sys.stdin))
                self.argfiles.append({'arg': arg[0],
                                      'content': helpers.base64_encode_file(data),
                                      'filename': ''})
                data = None
            elif not os.path.isfile(arg[1]):
                LOG.warning("unable to find file: %r", arg[1])
            else:
                self.argfiles.append({'arg': arg[0],
                                      'content': helpers.base64_encode_file(arg[1]),
                                      'filename': os.path.basename(arg[1])})

    def _build_uri(self, uri, method):
        """
        Return `uri` with its third split component (presumably the path)
        replaced by `/<method>/<endpoint>/<uid>`.
        """
        r    = list(urisup.uri_help_split(uri))
        r[2] = "/%s/%s/%s" % (method, self.options.endpoint, self.options.uid)

        return urisup.uri_help_unsplit(r)

    @staticmethod
    def _build_headers(headers = None):
        """
        Return `headers` (a new dict when empty or None) with the User-Agent
        set to `<name>/<version>`.
        """
        if not headers:
            headers = {}

        headers['User-Agent'] = ("%s/%s" % (SYSLOG_NAME, __version__))

        return headers

    def _read_response(self, req):
        """
        Decode and validate the JSON body of an HTTP response.

        @param req: requests response object
        @return: the validated decoded body
        @raise LookupError: see _check_results(); also raised on empty body
        @raise requests.exceptions.HTTPError: when the body is not JSON and
                                              the HTTP status is an error
        """
        if not req.text:
            return self._check_results()

        try:
            rs = req.json()
        except ValueError:
            # Not JSON: prefer reporting the HTTP error when there is one
            req.raise_for_status()
            return self._check_results()

        return self._check_results(rs)

    def do_run(self):
        """
        Post the job on the first reachable daemon and return its validated
        response. The daemon that answered is remembered in self.uri.

        @raise requests.exceptions.ConnectionError: when no daemon is reachable
        @raise LookupError: when the response is not a valid result
        """
        req      = None
        self.uri = None

        try:
            for uri in self.options.uri:
                try:
                    req  = requests.post(self._build_uri(uri, 'run'),
                                         auth = self._auth,
                                         headers = self._build_headers(),
                                         json = {'args':     self.options.args,
                                                 'argfiles': self.argfiles,
                                                 'env':      self.envvars,
                                                 'envfiles': self.options.envfiles})
                except exceptions.ConnectionError:
                    continue

                # Stop at the first daemon that answered: going on would post
                # the same job to every other reachable daemon.
                self.uri = uri
                break

            if not self.uri:
                raise exceptions.ConnectionError("unable to connect autond")

            return self._read_response(req)
        finally:
            # Explicit None test: a Response is falsy on 4xx/5xx statuses
            if req is not None:
                req.close()

    def do_status(self):
        """
        Fetch the job status, from the daemon remembered in self.uri when
        there is one, otherwise from the first reachable daemon (which is
        then remembered).

        @raise requests.exceptions.ConnectionError: when no daemon is reachable
        @raise LookupError: when the response is not a valid result
        """
        req = None

        try:
            if not self.uri:
                for uri in self.options.uri:
                    try:
                        req = requests.get(self._build_uri(uri, 'status'),
                                           auth = self._auth,
                                           headers = self._build_headers())
                    except exceptions.ConnectionError:
                        continue

                    # First reachable daemon wins
                    self.uri = uri
                    break

                if not self.uri:
                    raise exceptions.ConnectionError("unable to connect autond")
            else:
                req = requests.get(self._build_uri(self.uri, 'status'),
                                   auth = self._auth,
                                   headers = self._build_headers())

            return self._read_response(req)
        finally:
            # Explicit None test: a Response is falsy on 4xx/5xx statuses
            if req is not None:
                req.close()

    def do_autorun(self):
        """
        Run the job, then poll its status every options.delay seconds,
        showing the output, until the status is 'complete'.

        @return: the job 'return_code'; or 0/1 depending on whether 'code'
                 is 200 when --no-return-code is set or there is no
                 return code
        """
        data = self.do_run()
        while data['status'] != 'complete':
            time.sleep(self.options.delay)
            data = self.do_status()
            self._show_results(data)

        if self.options.no_return_code or data['return_code'] is None:
            return int(data['code'] != 200)

        return data['return_code']


def main(options):
    """
    Main function

    Configure logging, run the client in the requested mode and exit the
    process. This function never returns: it always ends with sys.exit().

    Exit status: the value returned by the mode handler, 1 when an error
    occurred, 2 when interrupted.

    @param options: namespace as returned by argv_parse_check()
    """
    xformat = "%(levelname)s:%(asctime)-15s: %(message)s"
    datefmt = '%Y-%m-%d %H:%M:%S'
    logging.basicConfig(level   = options.loglevel,
                        format  = xformat,
                        datefmt = datefmt)

    # A bare file name has an empty dirname: it lives in the current directory
    logdir  = os.path.dirname(options.logfile) or os.curdir
    # File logging is best-effort: silently skipped when the directory is
    # missing or not writable
    if os.path.isdir(logdir) and os.access(logdir, os.W_OK):
        filehandler = WatchedFileHandler(options.logfile)
        filehandler.setFormatter(logging.Formatter(xformat,
                                                   datefmt = datefmt))
        root_logger = logging.getLogger('')
        root_logger.addHandler(filehandler)

    rc      = 0

    try:
        auton = AutonClient(options)
        # NOTE(review): do_run/do_status return the decoded response rather
        # than an integer; sys.exit() prints a non-integer object on stderr
        # and exits with status 1 -- confirm this is the intended behaviour
        # of the 'run' and 'status' modes
        rc    = getattr(auton, "do_%s" % options.mode)()
    except (KeyboardInterrupt, SystemExit):
        rc = 2
    except Exception as e: # pylint: disable=broad-except
        # No re-raise: sys.exit() below would replace the exception anyway
        LOG.error(e)
        rc = 1
    finally:
        sys.exit(rc)


# Script entry point: parse the command line, then run the client
# (main() ends the process with sys.exit())
if __name__ == '__main__':
    main(argv_parse_check())
