#!python

"""
Script to register Accounting records to SGAS LUTS service.
Intended to be run from cron regularly (every hour or so).

This file is a bit messy, as it contains many things that would normally be
in separate modules, but they are kept in this single file in order to make
deployment easy (no imports, no problems setting up PYTHONPATH, etc).

Author: Magnus Jonsson <magnus@hpc2n.umu.se>
Original Author: Henrik Thostrup Jensen <htj@ndgf.org>
Copyright: NorduNET / Nordic Data Grid Facility (2009-2011)
Copyright: HPC2N, Umeå university
"""

import configparser
import logging
import os
import sys
import time
from argparse import ArgumentParser
from urllib.parse import urlparse
from xml.etree import ElementTree as ET

import requests

# Log defaults
# Format string handed to logging.basicConfig() by setupLogging().
LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"

# config sections
# main() reads the certificate and spool-directory options from [common];
# every other option (including the logging options) is read from [logger].
CONFIG_SECTION_COMMON = "common"
CONFIG_SECTION_LOGGER = "logger"
# Option names inside the sections above.
CONFIG_HOSTKEY = "x509_user_key"
CONFIG_HOSTCERT = "x509_user_cert"
CONFIG_CERTDIR = "x509_cert_dir"
# Two spellings of the spool directory option: main() prefers "logdir" and
# falls back to "log_dir" before using DEFAULT_LOG_DIR.
CONFIG_LOGDIR = "logdir"
CONFIG_LOG_DIR = "log_dir"
# Endpoints every record is registered to (parsed by parseLogAll()).
CONFIG_LOG_ALL = "log_all"
# Per-VO endpoints, "<vo> <url>" pairs separated by commas (parsed by parseLogVO()).
CONFIG_LOG_VO = "log_vo"
CONFIG_BATCH_SIZE = "batch_size"
CONFIG_RECORD_LIFETIME = "record_lifetime"
CONFIG_LOGFILE = "registrant_logfile"
CONFIG_LOGLEVEL = "loglevel"
# NOTE(review): not read anywhere in the code visible in this file.
CONFIG_STDERR_LEVEL = "stderr_level"
CONFIG_NAMESPACE = "namespace"
CONFIG_RECORDS_TAG = "records_tag"
CONFIG_REGISTRATION_TAG = "registration_tag"
CONFIG_TIMEOUT = "timeout"
# subdirectories in the spool directory
CONFIG_RECORDS_DIRECTORY = "records_directory"
CONFIG_STATE_DIRECTORY = "state_directory"
CONFIG_ARCHIVE_DIRECTORY = "archive_directory"

# system defaults
# Used by getOptions() when -c/--config is not given.
DEFAULT_CONFIG_FILE = "/etc/sams/sgas-sa-registrant.conf"

# configuration defaults
# Each value is the fallback for the matching CONFIG_* option.
DEFAULT_LOGFILE = "/var/log/sgas-sa-registration.log"
DEFAULT_HOSTKEY = "/etc/grid-security/hostkey.pem"
DEFAULT_HOSTCERT = "/etc/grid-security/hostcert.pem"
# None means httpRequest() does not pass a "verify" argument to requests.
DEFAULT_CERTDIR = None
DEFAULT_LOG_DIR = "/var/spool/softwareaccounting/"
# Number of record files joined into one registration request
# (see the slicing in performURRegistration()).
DEFAULT_BATCH_SIZE = 1
# Two comma separated numbers; parseTimeout() turns them into the tuple that
# is passed as the "timeout" argument of requests.request().
DEFAULT_TIMEOUT = "10,10"
# Archived records older than this are removed by deleteOldUsageRecords().
DEFAULT_UR_LIFETIME = 30  # days
# NOTE(review): not read anywhere in the code visible in this file.
DEFAULT_STDERR_LEVEL = logging.ERROR
DEFAULT_LOGLEVEL = logging.INFO
# XML namespace and tag names used to build the qualified names in Config.
DEFAULT_NAMESPACE = "http://sams.snic.se/namespaces/2019/01/softwareaccountingrecords"
DEFAULT_RECORDS_TAG = "SoftwareAccountingRecord"
# Service name looked up in the endpoint's service listing
# (see createEPRegistrationMapping()).
DEFAULT_REGISTRATION_TAG = "SoftwareAccountingRegistration"

# subdirectories in the spool directory
DEFAULT_RECORDS_DIRECTORY = "records"
DEFAULT_STATE_DIRECTORY = "state"
DEFAULT_ARCHIVE_DIRECTORY = "archive"


# -- code
class StateFile:
    """
    Abstraction for a record state file, which lists the endpoints a record
    has already been registered to (one URL per line).

    The state file lives at ``<logdir>/<state_directory>/<filename>``. If it
    does not exist yet, the state directory is created and the set of
    registered endpoints starts out empty; the file itself is only written
    by ``write()``.
    """

    def __init__(self, logdir, filename, state_directory):
        self.logdir = logdir
        self.filename = filename
        self.state_directory = state_directory

        # EAFP: open the file directly instead of checking for existence
        # first, so a concurrently running instance cannot invalidate the check.
        try:
            with open(self._filepath(), encoding="utf-8") as f:
                self.urls = {line.strip() for line in f if line.strip()}
        except FileNotFoundError:
            # exist_ok avoids failing if the directory appears between the
            # failed open and this call.
            os.makedirs(os.path.join(logdir, state_directory), exist_ok=True)
            self.urls = set()

    def _filepath(self):
        """Return the full path of the state file."""
        return os.path.join(self.logdir, self.state_directory, self.filename)

    def __contains__(self, ele):
        """Return True if the endpoint *ele* is recorded as registered."""
        return ele in self.urls

    def add(self, ele):
        """Record the endpoint *ele* as registered (in memory only) and return self."""
        self.urls.add(ele)
        return self  # makes it possible to do one-liners

    def write(self):
        """Write the registered endpoints to the state file, one per line."""
        with open(self._filepath(), "w", encoding="utf-8") as f:
            # sorted so the file content is deterministic between runs
            f.writelines(url + "\n" for url in sorted(self.urls))


class ContextFactory:
    """
    Holder for the TLS credential locations used for outgoing requests:
    the private key file, the certificate file and, optionally, the
    directory with the CA certificates to verify the server against.
    """

    def __init__(self, key_path, cert_path, ca_dir=None):
        # Only the paths are stored; nothing is opened or validated here.
        self.key_path, self.cert_path, self.ca_dir = key_path, cert_path, ca_dir


def getConfigOption(cfg, section, option, default=None):
    """
    Return the value of *option* in *section* of the parsed config *cfg*.

    String values are stripped of surrounding whitespace and of every single
    and double quote character. If the section or option is missing,
    *default* is returned; non-string defaults are returned unchanged.
    """
    value = cfg.get(section, option, fallback=default)
    if not isinstance(value, str):
        return value
    # A conditional is used instead of the "cond and a or b" idiom, which
    # returned the raw (still quoted) string when the cleaned value was empty.
    return value.strip().replace('"', "").replace("'", "")


def parseLogAll(value):
    """
    Split the whitespace separated list of endpoints in *value*.

    Returns an empty list when *value* is None or blank, so callers can
    detect "no endpoints configured" with a simple truth test. Runs of
    whitespace never produce empty entries.
    """
    if not value:
        return []
    return value.split()


def parseLogVO(value):
    """
    Parse the per-VO endpoint option into a ``{vo_name: url}`` dict.

    *value* is a comma separated list of ``<vo name> <url>`` pairs. None or
    an empty string gives an empty dict. Empty entries (for example after a
    trailing comma) are ignored and any amount of whitespace may separate
    the VO name from the URL.

    Raises ValueError for an entry that does not consist of exactly two
    whitespace separated fields.
    """
    vo_regs = {}
    if not value:
        return vo_regs

    for pair in value.split(","):
        fields = pair.split()
        if not fields:
            continue  # tolerate empty entries such as a trailing comma
        if len(fields) != 2:
            raise ValueError(f"Invalid VO endpoint entry '{pair.strip()}', expected '<vo name> <url>'")
        vo_name, url = fields
        vo_regs[vo_name] = url
    return vo_regs


def getVONamesFromUsageRecord(ure, config):
    """
    Collect the text of every VO name element of the parsed usage record *ure*.

    Only elements on the path root -> config.user_identity -> config.vo ->
    config.vo_name are considered; the names are returned in document order.
    """

    # The original author noted that findall() on the root did not work
    # here, so the tree is walked level by level instead.
    def matching(parent, tag):
        return (child for child in parent if child.tag == tag)

    names = []
    for identity in matching(ure.getroot(), config.user_identity):
        for vo in matching(identity, config.vo):
            names.extend(name.text for name in matching(vo, config.vo_name))
    return names


def parseTimeout(value):
    """
    Parse a timeout option of the form ``"<first>,<second>"`` into a tuple.

    Each part is converted to float; a blank part becomes None. A single
    number without a comma is accepted as well and gives ``(number, None)``.
    The tuple is passed on unchanged as the ``timeout`` argument of
    ``requests.request``.

    Raises ValueError if a part is not a number (which includes values with
    more than one comma).
    """

    def to_seconds(text):
        text = text.strip()
        return float(text) if text else None

    # partition() always yields both parts, so a value without a comma no
    # longer fails on tuple unpacking.
    first, _, second = value.partition(",")
    return (to_seconds(first), to_seconds(second))


def parseRecordLifeTime(value):
    """Convert a record lifetime given in days (int or numeric string) to seconds."""
    seconds_per_day = 24 * 60 * 60
    return int(value) * seconds_per_day


def createRegistrationPointsMapping(logdir, logpoints_all, logpoints_vo, config):
    """
    Create a mapping from endpoint to the list of usage record filenames
    that should be registered to it.

    Every regular file in ``<logdir>/<config.records_directory>`` that parses
    as XML is mapped to all endpoints in *logpoints_all* and, for each VO
    name found in the record, to the endpoint *logpoints_vo* has for that VO.
    Files that cannot be parsed are logged and skipped.

    A file is listed at most once per endpoint, even when a VO endpoint is
    also a "log all" endpoint or several of its VOs share an endpoint;
    otherwise the same record would be sent twice in one registration.
    """
    logging.info("Creating registration mapping (may take a little time)")
    mapping = {}

    record_dir = os.path.join(logdir, config.records_directory)
    for filename in os.listdir(record_dir):
        filepath = os.path.join(record_dir, filename)
        # skip if file is not a proper file
        if not os.path.isfile(filepath):
            continue

        try:
            ure = ET.parse(filepath)
        except Exception as e:
            # best effort: one broken record must not stop the others
            logging.info(f"Error parsing file {filepath}, ({str(e)}) continuing")
            continue

        endpoints = list(logpoints_all)
        for vo in getVONamesFromUsageRecord(ure, config):
            vo_lp = logpoints_vo.get(vo)
            if vo_lp:
                endpoints.append(vo_lp)

        # dict.fromkeys() drops duplicates while keeping the original order
        for lp in dict.fromkeys(endpoints):
            mapping.setdefault(lp, []).append(filename)

    return mapping


def createFileEPMapping(epmap):
    """
    Invert an endpoint -> [filename] mapping into filename -> [endpoint].

    The inverted form makes it easy to tell when every registration for a
    given file has been made.
    """
    by_file = {}
    for endpoint, names in epmap.items():
        for name in names:
            if name in by_file:
                by_file[name].append(endpoint)
            else:
                by_file[name] = [endpoint]
    return by_file


class HttpRequestException(Exception):
    """Raised by httpRequest() when the server answers with a non-200 status."""


def httpRequest(url, method="GET", payload=None, ctxFactory=None, timeout=None):
    """
    Perform an HTTP request and return the body of the response.

    When *ctxFactory* is given, its certificate and key paths are used as the
    client certificate and, if it has a CA directory, that directory is used
    to verify the server. A non-empty *payload* is sent as the request body.

    Raises HttpRequestException if the response status is not 200.
    """
    request_args = {"timeout": timeout}

    if ctxFactory:
        request_args["cert"] = (ctxFactory.cert_path, ctxFactory.key_path)
        if ctxFactory.ca_dir:
            request_args["verify"] = ctxFactory.ca_dir

    if payload:
        request_args["data"] = payload

    response = requests.request(method, url, **request_args)
    if response.status_code == 200:
        return response.content

    raise HttpRequestException(f"Got a non 200 response from server {response}")


def createEPRegistrationMapping(endpoints, config):
    """
    Ask every endpoint for its service listing and return a dict mapping the
    endpoint to the URL of its registration service.

    The listing is expected to contain "service" elements whose "name"
    element equals config.registration_tag, followed by an "href" element
    with the location. Endpoints that cannot be reached, return an unusable
    reply or do not list the service are logged and left out of the result.
    """

    def resolve_location(location, endpoint):
        # A complete url is used as it is.
        if location.startswith("http"):
            return location
        if not location.startswith("/"):
            raise ValueError(f"Invalid registration point returned by {endpoint} (got: {location})")
        # A bare path is combined with scheme and host of the endpoint.
        base = urlparse(endpoint)
        return f"{base.scheme}://{base.netloc}{location}"

    def find_registration_url(document, endpoint):
        for service in ET.fromstring(document):
            if service.tag != "service":
                continue
            name_seen = False
            for child in service:
                if child.tag == "name" and child.text == config.registration_tag:
                    name_seen = True
                elif child.tag == "href" and name_seen:
                    # only an href that follows the matching name counts
                    return resolve_location(child.text, endpoint)
        return None  # no registration service found

    logging.info("Retrieving registration hrefs (service endpoints)")

    regmap = {}
    for ep in endpoints:
        try:
            reply = httpRequest(ep, ctxFactory=config.context_factory, timeout=config.timeout)
            href = find_registration_url(reply, ep)
        except Exception as err:
            logging.error(err)
            continue

        if href is None:
            logging.error(f"Endpoint {ep} does not appear to have a registration service.")
            continue
        regmap[ep] = href

    return regmap


def joinRecordFiles(logdir, filenames, records, records_directory):
    """
    Parse the given record files from ``<logdir>/<records_directory>`` and
    return them serialized as children of one element with the tag *records*.
    """
    record_dir = os.path.join(logdir, records_directory)
    container = ET.Element(records)
    for name in filenames:
        document = ET.parse(os.path.join(record_dir, name))
        container.append(document.getroot())
    return ET.tostring(container)


def registerBatch(ep, url, logdir, filenames, config):
    """
    Post the given record files as one document to the registration *url*
    and, on success, note in the state file of each record that it has been
    registered to the endpoint *ep*.

    Any error while posting or updating the state files is logged and
    raised again.
    """
    document = joinRecordFiles(logdir, filenames, config.records, config.records_directory)
    try:
        httpRequest(
            url,
            method="POST",
            payload=document,
            ctxFactory=config.context_factory,
            timeout=config.timeout,
        )
        logging.info(f"{len(filenames)} records registered to {ep}")
        for name in filenames:
            state = StateFile(logdir, name, config.state_directory)
            state.add(ep)
            state.write()
    except Exception as err:
        logging.error(f"Error during batch insertion: {err}")
        logging.debug(err)
        raise err


def registerUsageRecords(mapping, logdir, config):
    """
    Run the registration for the given endpoint -> [filename] *mapping*:
    look up the registration URLs, register the records and archive the
    records that have been registered everywhere.
    """
    urmap = createFileEPMapping(mapping)
    if urmap:
        logging.info(f"Registrations to perform: {len(urmap)} files")
        regmap = createEPRegistrationMapping(mapping.keys(), config)
        performURRegistration(regmap, urmap, logdir, config)
        archiveUsageRecords(logdir, urmap, config)
    else:
        # no registration to perform
        logging.info("No registrations to perform")


def performURRegistration(regmap, urmap, logdir, config):
    """
    Register the records in *urmap* (filename -> [endpoint]) to the
    registration URLs in *regmap* (endpoint -> url).

    Records already marked as registered in their state file are skipped,
    and records for endpoints missing from *regmap* are left for a later
    run. The remaining records are sent in batches of config.batch_size;
    after the first failed batch no more batches are sent to that endpoint.
    """
    if not regmap:
        logging.error("Failed to get any service refs, not doing any registrations")
        return

    pending = {}
    for endpoint, registration_url in regmap.items():
        logging.info(f"{endpoint} -> {registration_url}")
        pending[endpoint] = []

    logging.info("Starting registration")

    # Sort every (file, endpoint) pair into "already done" or "to do".
    already_registered = {}
    for filename, endpoints in urmap.items():
        state = StateFile(logdir, filename, config.state_directory)
        for endpoint in endpoints:
            if endpoint in state:
                already_registered[endpoint] = already_registered.get(endpoint, 0) + 1
            elif endpoint in pending:
                pending[endpoint].append(filename)
            # otherwise: deferring registration as service is not available

    for endpoint, count in already_registered.items():
        logging.info(f"Skipping {count} registrations to {endpoint}, records already registered")

    # build up registration batches (list of (ep, filenames) tuples)
    size = config.batch_size
    registrations = []
    for endpoint, filenames in pending.items():
        for start in range(0, len(filenames), size):
            registrations.append((endpoint, filenames[start : start + size]))

    failed = set()
    for endpoint, batch in registrations:
        if endpoint in failed:
            continue

        try:
            registerBatch(endpoint, regmap[endpoint], logdir, batch, config)
        except Exception as err:
            logging.error(f"Error registration records to {endpoint}")
            logging.error("Skipping all registrations to this endpoint for now")
            logging.debug(err)
            failed.add(endpoint)


def archiveUsageRecords(logdir, urmap, config):
    """
    Move every record that has been registered to all of its endpoints from
    the records directory to the archive directory and remove its state file.

    *urmap* maps a record filename to the endpoints it has to be registered
    to. Records with at least one endpoint missing from their state file
    are left in place for the next run.
    """
    logging.info("Registration done, commencing archiving process")
    archive_dir = os.path.join(logdir, config.archive_directory)
    os.makedirs(archive_dir, exist_ok=True)

    for filename, endpoints in urmap.items():
        state = StateFile(logdir, filename, config.state_directory)
        if not all(ep in state for ep in endpoints):
            continue  # not registered everywhere yet

        urfilepath = os.path.join(logdir, config.records_directory, filename)
        statefilepath = os.path.join(logdir, config.state_directory, filename)
        archivefilepath = os.path.join(archive_dir, filename)
        # Move the record first and drop its state afterwards: if the move
        # fails, the state file survives and the record is not registered a
        # second time on the next run.
        os.rename(urfilepath, archivefilepath)
        os.unlink(statefilepath)

    logging.info("Archiving done")


def deleteOldUsageRecords(log_dir, ttl_seconds, archive_directory):
    """
    Delete the files in ``<log_dir>/<archive_directory>`` whose ctime is more
    than *ttl_seconds* in the past.

    Only regular files are considered. A missing archive directory (nothing
    has been archived yet) is not an error; there is simply nothing to
    clean up.
    """
    archive_dir = os.path.join(log_dir, archive_directory)
    logging.info("Cleaning up old records.")

    if not os.path.isdir(archive_dir):
        logging.info(f"Archive directory {archive_dir} does not exist, nothing to clean up")
        return

    # anything last changed before this point in time is considered old
    cutoff = time.time() - ttl_seconds

    deleted = 0
    for filename in os.listdir(archive_dir):
        filepath = os.path.join(archive_dir, filename)
        # skip if file is not a proper file
        if not os.path.isfile(filepath):
            continue

        # use ctime to determine file age
        if os.stat(filepath).st_ctime < cutoff:
            # file is old, will get deleted
            os.unlink(filepath)
            deleted += 1

    logging.info(f"Records deleted: {deleted}")


def getOptions():
    """Parse the command line arguments and return the resulting namespace."""
    cli = ArgumentParser()
    cli.add_argument(
        "-l",
        "--log-file",
        dest="logfile",
        help="Log file (overwrites config option).",
    )
    cli.add_argument("-d", "--debug", action="store_true", default=False, help="Set log level to DEBUG")
    cli.add_argument(
        "-c",
        "--config",
        "--config-file",
        dest="config",
        metavar="FILE",
        default=DEFAULT_CONFIG_FILE,
        help="Configuration file.",
    )
    cli.add_argument(
        "-s",
        "--stdout",
        action="store_true",
        default=False,
        help="Log to stdout",
    )
    options = cli.parse_args()
    return options


def setupLogging(options, cfg):
    """
    Configure the root logger from the command line *options* and the
    parsed config *cfg*.

    The level is DEBUG when --debug is given, otherwise the configured log
    level (or the default). Output goes to stdout when --stdout is given,
    otherwise to the log file from the command line, the config or the
    default, in that order.
    """
    # Log level
    if options.debug:
        loglevel = logging.DEBUG
    else:
        loglevel = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_LOGLEVEL, DEFAULT_LOGLEVEL)
        if isinstance(loglevel, str):
            # Values from the config file are strings. logging only knows
            # the upper case level names, and numeric levels must be ints,
            # so "debug" or "10" would otherwise raise ValueError.
            loglevel = int(loglevel) if loglevel.isdigit() else loglevel.upper()

    # open logfile
    if options.stdout:
        logging.basicConfig(
            stream=sys.stdout,
            format=LOG_FORMAT,
            level=loglevel,
        )
    else:
        logfile = options.logfile
        if logfile is None:
            logfile = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_LOGFILE, DEFAULT_LOGFILE)
        logging.basicConfig(
            filename=logfile,
            format=LOG_FORMAT,
            level=loglevel,
        )


class Config:
    """
    Bundle of the settings that are handed around between the registration
    functions.

    The tag names (records_tag, user_identity, vo, vo_name) are combined with
    *namespace* into qualified names; the namespace itself is not stored.
    """

    def __init__(
        self,
        context_factory,
        batch_size,
        registration_tag,
        namespace,
        records_tag,
        timeout,
        records_directory,
        state_directory,
        archive_directory,
        user_identity,
        vo,
        vo_name,
    ):
        def qualified(tag):
            # "{namespace}tag", the notation ElementTree uses for qualified names
            return ET.QName(f"{{{namespace}}}{tag}")

        # connection settings
        self.context_factory = context_factory
        self.timeout = timeout
        self.batch_size = batch_size

        # spool sub directories
        self.records_directory = records_directory
        self.state_directory = state_directory
        self.archive_directory = archive_directory

        # XML names
        self.registration_tag = registration_tag
        self.records = qualified(records_tag)
        self.user_identity = qualified(user_identity)
        self.vo = qualified(vo)
        self.vo_name = qualified(vo_name)


class ConfigurationError(Exception):
    """Raised when the configuration cannot be used, e.g. a missing config file."""


def main():
    """
    Parse command line, setup logging, start the actual logic, etc.

    Reads the config file, builds the Config object and then runs the three
    steps of a registration run: map records to endpoints, register (and
    archive) them, and delete archived records that are too old.

    Raises ConfigurationError if the config file is missing or the batch
    size is not a positive integer. Exits with status 1 if no endpoints are
    configured or the spool directory does not exist.
    """

    # start by parsing the command line to see if we have a specific config file
    options = getOptions()

    cfg_file = options.config
    # isfile() is False for a missing path too, so no separate exists() check
    if not os.path.isfile(cfg_file):
        raise ConfigurationError(f"The config file '{cfg_file}' does not exist or is not a file")

    # read config
    cfg = configparser.ConfigParser()
    cfg.read(cfg_file)

    # Setup Logging
    setupLogging(options, cfg)

    # Certificates
    host_key = getConfigOption(cfg, CONFIG_SECTION_COMMON, CONFIG_HOSTKEY, DEFAULT_HOSTKEY)
    host_cert = getConfigOption(cfg, CONFIG_SECTION_COMMON, CONFIG_HOSTCERT, DEFAULT_HOSTCERT)
    cert_dir = getConfigOption(cfg, CONFIG_SECTION_COMMON, CONFIG_CERTDIR, DEFAULT_CERTDIR)

    # Where are records stored? "logdir" wins over "log_dir", then the default.
    log_dir = getConfigOption(
        cfg,
        CONFIG_SECTION_COMMON,
        CONFIG_LOGDIR,
        getConfigOption(cfg, CONFIG_SECTION_COMMON, CONFIG_LOG_DIR, DEFAULT_LOG_DIR),
    )

    # Logger options
    las = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_LOG_ALL)
    lvo = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_LOG_VO)
    rlt = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_RECORD_LIFETIME, DEFAULT_UR_LIFETIME)
    log_all = parseLogAll(las)
    log_vo = parseLogVO(lvo)
    record_lifetime = parseRecordLifeTime(rlt)

    timeout = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_TIMEOUT, DEFAULT_TIMEOUT)
    batch_size = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_BATCH_SIZE, DEFAULT_BATCH_SIZE)

    # The batch size is used as the step when slicing the list of records:
    # zero would raise later on and a negative value would silently register
    # nothing, so reject both here with a clear message.
    try:
        batch_size = int(batch_size)
    except (TypeError, ValueError) as err:
        raise ConfigurationError(f"Invalid {CONFIG_BATCH_SIZE} '{batch_size}', must be an integer") from err
    if batch_size < 1:
        raise ConfigurationError(f"Invalid {CONFIG_BATCH_SIZE} '{batch_size}', must be at least 1")

    # XML Tags and Namespaces
    registration_tag = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_REGISTRATION_TAG, DEFAULT_REGISTRATION_TAG)
    records_tag = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_RECORDS_TAG, DEFAULT_RECORDS_TAG)
    namespace = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_NAMESPACE, DEFAULT_NAMESPACE)

    records_directory = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_RECORDS_DIRECTORY, DEFAULT_RECORDS_DIRECTORY)
    state_directory = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_STATE_DIRECTORY, DEFAULT_STATE_DIRECTORY)
    archive_directory = getConfigOption(cfg, CONFIG_SECTION_LOGGER, CONFIG_ARCHIVE_DIRECTORY, DEFAULT_ARCHIVE_DIRECTORY)

    config = Config(
        context_factory=ContextFactory(host_key, host_cert, cert_dir),
        registration_tag=registration_tag,
        records_tag=records_tag,
        namespace=namespace,
        timeout=parseTimeout(timeout),
        batch_size=batch_size,
        records_directory=records_directory,
        state_directory=state_directory,
        archive_directory=archive_directory,
        user_identity="UserIdentity",
        vo="VO",
        vo_name="Name",
    )

    logging.info("Configuration:")
    logging.info(f" Log dir: {log_dir}")
    logging.info(f" Log all: {log_all}")
    logging.debug(f" Host key  : {host_key}")
    logging.debug(f" Host cert : {host_cert}")
    logging.debug(f" Cert dir  : {cert_dir}")

    if not log_all:
        logging.error("No log points given. Cowardly refusing to do anything")
        sys.exit(1)

    if not os.path.exists(log_dir):
        logging.error(f"Log directory {log_dir} does not exist, bailing out.")
        sys.exit(1)

    mapping = createRegistrationPointsMapping(log_dir, log_all, log_vo, config)
    registerUsageRecords(mapping, log_dir, config)
    deleteOldUsageRecords(log_dir, record_lifetime, config.archive_directory)


if __name__ == "__main__":
    # Top-level boundary: any error that escapes main() is logged together
    # with its traceback and turned into exit status 1.
    try:
        main()
    except Exception as exc:
        logging.exception(exc)
        sys.exit(1)
