#!/usr/bin/env python

"""inventory_ad
===============

This script is an Ansible Dynamic Inventory script. It takes the host name of a
client, searches for it in MS Active Directory and returns its groups. The
groups are filtered by prefix and Active Directory Organizational Unit.

This script returns the group and hostname in an Ansible compatible JSON
format. The script writes an offline cache file for later use if the Active
Directory query succeeds. This cache is returned if the Active Directory query
fails.

It always returns the common group for common tasks. It also returns the
AD-connected group (configured as ``common_ad_connected``, e.g.
common-ad-bound) if the result comes from an online query rather than from the
cache.
"""
from __future__ import print_function

__author__ = "Jan Welker"
__email__ = "jan.welker@unibas.ch"
__copyright__ = "Copyright 2017, University of Basel"

__credits__ = ["Balz Aschwanden", "Jan Welker"]
__license__ = "GPL"

from json import dumps
from os import path
from socket import gethostname
from ssl import CERT_REQUIRED

from antslib import configer, logger
from ldap3 import NTLM, Connection, Server, Tls, core


def connect_to_ad(ldap_user, ldap_pw, ldap_host):
    """Open an NTLM-authenticated LDAPS connection to Active Directory.

    The server certificate is validated (``CERT_REQUIRED``).

    Returns:
        The bound ldap3 ``Connection`` when the bind result description is
        ``"success"``, otherwise None. Any ldap3 error raised while binding
        (including a socket that cannot be opened) is logged and also yields
        None, so the caller can fall back to the offline cache.
    """
    tls = Tls(validate=CERT_REQUIRED)
    server = Server(ldap_host, use_ssl=True, tls=tls)
    connection = Connection(
        server, user=ldap_user, password=ldap_pw, authentication=NTLM
    )
    try:
        connection.bind()
    except core.exceptions.LDAPException as error:
        # LDAPException is the ldap3 base class, so this still covers the
        # socket-open failure as well as other bind-time errors.
        logger.logfile_logger.info(
            "Could not bind to AD at %s: %s" % (ldap_host, error)
        )
        bound = False
    else:
        bound = connection.result["description"] == "success"

    if bound:
        return connection

    # The connection is discarded; unbind so a rejected bind does not leave
    # its socket open.
    try:
        connection.unbind()
    except core.exceptions.LDAPException:
        pass  # best effort: the connection is being abandoned anyway
    return None


def get_simple_host_name(fqdn):
    """Return the short host name, i.e. everything before the first dot of fqdn."""
    short_name, _separator, _domain = fqdn.partition(".")
    return short_name


def host_exist_in_ad(connection, simple_hostname, ldap_ou):
    """Return True if an object with cn *simple_hostname* exists under *ldap_ou*.

    The host does not have to be bound to AD, it just has to exist. Only the
    first search result is inspected: the host counts as present when that
    entry has a non-empty ``cn`` attribute. An empty result, a missing
    response, or a first result without an ``attributes``/``cn`` key is
    logged and yields False.
    """
    connection.search(ldap_ou, "(cn=%s)" % simple_hostname, attributes=["cn"])
    try:
        response = connection.response[0]["attributes"]["cn"]
    except (IndexError, KeyError, TypeError) as error:
        # IndexError: the search matched nothing (empty response list).
        # TypeError: there is no response list at all.
        logger.logfile_logger.info(
            "Host %s not found in %s: %s" % (simple_hostname, ldap_ou, error)
        )
        response = ""

    return bool(response)


def get_computer_dn(connection, simple_hostname, ldap_ou):
    """Return the distinguished name of the object with cn *simple_hostname*.

    Searches *ldap_ou* and reads ``distinguishedName`` from the first result.
    Returns "" (and logs at info level) when there is no result, no response,
    or the first result lacks the attribute.
    """
    connection.search(
        ldap_ou, "(cn=%s)" % simple_hostname, attributes=["distinguishedName"]
    )
    try:
        response = connection.response[0]["attributes"]["distinguishedName"]
    except (IndexError, KeyError, TypeError) as error:
        # The format arguments must be one tuple for "%" to fill all three
        # placeholders.
        logger.logfile_logger.info(
            "DN of %s not found in %s: %s" % (simple_hostname, ldap_ou, error)
        )
        response = ""
    return response


def _escape_filter_value(value):
    """Escape *value* for use as an LDAP filter assertion value (RFC 4515).

    Same replacements as ldap3.utils.conv.escape_filter_chars for text. The
    backslash is replaced first so the escapes added afterwards are not
    themselves re-escaped.
    """
    for char, escaped in (
        ("\\", "\\5c"),
        ("*", "\\2a"),
        ("(", "\\28"),
        (")", "\\29"),
        ("\x00", "\\00"),
    ):
        value = value.replace(char, escaped)
    return value


def get_computer_groups(connection, search_base, computer_dn, group_prefix):
    """Return the lower-cased names of groups the computer object belongs to.

    member:1.2.840.113556.1.4.1941:=<dn> is a special Active Directory OID
    (LDAP_MATCHING_RULE_IN_CHAIN) that returns nested groups and not just the
    first level. Only groups found under *search_base* whose lower-cased cn
    starts with *group_prefix* (compared case-insensitively) are returned,
    each once, in response order.

    Returns an empty list (and logs at info level) when *computer_dn* is empty
    or no matching group is found.
    """
    result = list()
    if not computer_dn:
        logger.logfile_logger.info(
            "No computer DN given; skipping group lookup in %s" % search_base
        )
        return result

    # DNs may contain "\", "(", ")" or "*" (e.g. escaped commas), which must
    # be escaped before being placed inside a filter.
    connection.search(
        search_base,
        "(member:1.2.840.113556.1.4.1941:=%s)" % _escape_filter_value(computer_dn),
        attributes=["cn"],
    )
    # Group names are lower-cased, so the prefix must be too.
    prefix = group_prefix.lower()
    for entry in connection.response or []:
        try:
            group_name = entry["attributes"]["cn"].lower()
        except (KeyError, TypeError, AttributeError):
            # Entries without a usable string cn (e.g. search references that
            # carry no "attributes" key) cannot be groups we care about.
            continue
        if group_name.startswith(prefix) and group_name not in result:
            result.append(group_name)

    if not result:
        logger.logfile_logger.info(
            "No groups found for %s in %s with prefix %s"
            % (computer_dn, search_base, group_prefix)
        )
    return result


def format_output(output):
    """Serialize the inventory mapping as pretty-printed, key-sorted JSON.

    The layout (4-space indent, no trailing spaces after commas) is what the
    script prints for Ansible; the inventory format is documented at
    http://docs.ansible.com/ansible/latest/dev_guide/developing_inventory.html
    """
    item_separator, key_separator = ",", ": "
    return dumps(
        output,
        indent=4,
        sort_keys=True,
        separators=(item_separator, key_separator),
    )


def write_cache(cache_file, output):
    """Write the inventory *output* to *cache_file* in Ansible JSON format.

    Raises:
        IOError: if the file cannot be opened or written. The error and a
            permission/path hint are logged to the console logger before the
            exception is re-raised.
    """
    # Serialize before opening: open(..., "w") truncates the existing cache,
    # so a serialization failure must not happen after that point.
    document = format_output(output)
    try:
        with open(cache_file, "w") as cache:
            # One write for the whole document (iterating the string would
            # issue one write per character).
            cache.write(document)
    except IOError as error:
        logger.console_logger.error("Error while writing cache: %s" % error)
        logger.console_logger.error(
            "Make sure the base process has the right permissions and path exists for %s"
            % cache_file
        )
        raise


def read_cache(cache_file):
    """Return the cache file's text, or False when *cache_file* is not a regular file."""
    cached_text = False
    if path.isfile(cache_file):
        with open(cache_file, "r") as handle:
            cached_text = handle.read()
    return cached_text


def _query_ad_inventory(cfg, connection, fqdn, simple_host_name):
    """Build the inventory from Active Directory; return None on LDAP errors.

    The host (*fqdn*) is always placed in ``cfg["common_group"]``. If it
    exists under ``cfg["ldap_ou_computers"]`` it is also placed in every
    group returned by get_computer_groups. The connection is unbound before
    returning, whether or not the queries succeeded.
    """
    output = dict()
    output[cfg["common_group"]] = [fqdn]
    try:
        if host_exist_in_ad(connection, simple_host_name, cfg["ldap_ou_computers"]):
            computer_dn = get_computer_dn(
                connection, simple_host_name, cfg["ldap_ou_computers"]
            )
            computer_groups = get_computer_groups(
                connection, cfg["ldap_ou_groups"], computer_dn, cfg["group_prefix"]
            )
            for group in computer_groups:
                output[group] = [fqdn]
    except core.exceptions.LDAPException as error:
        # The script promises to serve the cache when the AD query fails, so
        # an LDAP error after a successful bind must not abort the run.
        logger.logfile_logger.info("AD query failed, using cache: %s" % error)
        return None
    finally:
        try:
            connection.unbind()
        except core.exceptions.LDAPException:
            pass  # best effort: the queries are finished either way
    return output


def _print_offline_inventory(cfg, cache_file, fqdn):
    """Print the cached inventory, or only the common group if there is none."""
    logger.logfile_logger.info("Using cached results from AD")
    cached_output = read_cache(cache_file)
    if cached_output:
        print(cached_output)
    else:
        output = dict()
        output[cfg["common_group"]] = [fqdn]
        print(format_output(output))


def main():
    """Fetch the host's groups from AD (or the cache) and print them as JSON."""
    cfg = configer.read_config("ad")
    cache_file = cfg["cache_file"]

    # gethostname() returns whatever name the OS reports. Its lower-cased
    # first label is used for the AD lookup, while the unmodified value is
    # what appears in the inventory.
    fqdn = gethostname()
    simple_host_name = get_simple_host_name(fqdn.lower())

    output = None
    ad_connection = connect_to_ad(cfg["ldap_user"], cfg["ldap_pw"], cfg["ldap_host"])
    if ad_connection:
        logger.logfile_logger.info("Using online results from AD")
        output = _query_ad_inventory(cfg, ad_connection, fqdn, simple_host_name)

    if output is None:
        # Not bound to AD, or the AD queries failed: we are offline.
        _print_offline_inventory(cfg, cache_file, fqdn)
        return

    try:
        write_cache(cache_file, output)
    except IOError as error:
        logger.console_logger.error("Error while writing cache: %s" % error)

    # Added only after the cache is written: this group marks an online
    # result and must never be served from the cache.
    output[cfg["common_ad_connected"]] = [fqdn]

    print(format_output(output))


if __name__ == "__main__":
    # Run only when executed as a script (e.g. invoked by Ansible as a
    # dynamic inventory), not when imported as a module.
    main()
