#!/usr/bin/env python

"""inventory_ad
===============


This script is an Ansible Dynamic Inventory script. It takes the host name of a
client, searches for it in MS Active Directory and returns its groups. The
groups are filtered by prefix and Active Directory Organizational Unit.

This script returns the group and hostname in an Ansible compatible JSON
format.  The script writes an offline cache file for later use if the Active
Directory query succeeds. This cache is returned if the Active Directory query
fails.

It always returns the common group for common tasks. It also returns the group
common-ad-bound if the result comes from an online query rather than from the
cache.
"""


__author__ = "Jan Welker"
__email__ = "jan.welker@unibas.ch"
__copyright__ = "Copyright 2017, University of Basel"

__credits__ = ["Balz Aschwanden", "Jan Welker"]
__license__ = "GPL"


import sys
from json import dumps
from os import path
from socket import gethostname
from ssl import CERT_REQUIRED

from ldap3 import Server, Connection, Tls, NTLM
from ldap3.core.exceptions import LDAPException


from antslib import configer
from antslib import logger


def connect_to_ad(ldap_user, ldap_pw, ldap_host):
    """Connect to Active Directory and return the connection or None.

    Builds an LDAP-over-SSL server definition that requires a valid server
    certificate and binds to it using NTLM authentication.

    Args:
        ldap_user: User name used for the NTLM bind.
        ldap_pw: Password belonging to ``ldap_user``.
        ldap_host: Host name of the Active Directory server.

    Returns:
        The bound ldap3 ``Connection`` when the bind result description is
        'success'. None when the bind is rejected or when the bind attempt
        raises an LDAP error (for example because the server cannot be
        reached), so that the caller can fall back to the offline cache.
    """
    tls = Tls(validate=CERT_REQUIRED)
    server = Server(ldap_host, use_ssl=True, tls=tls)
    connection = Connection(server, user=ldap_user, password=ldap_pw,
                            authentication=NTLM)
    try:
        connection.bind()
    except LDAPException:
        # ldap3 raises communication errors (unreachable host, failed TLS
        # handshake, ...) from bind(). Treat them like a failed bind instead
        # of letting the inventory script crash.
        return None

    if connection.result['description'] == 'success':
        return connection
    return None


def get_simple_host_name(fqdn):
    """Return the host part of a fully qualified domain name.

    Everything from the first dot onwards is dropped; a name that contains
    no dot is returned unchanged.
    """
    host_part, _, _ = fqdn.partition('.')
    return host_part


def host_exist_in_ad(connection, simple_hostname, ldap_ou):
    """Check if host can be found in Active Directory.

    The host does not have to be bound to AD, it just has to exist.

    Args:
        connection: Bound LDAP connection offering ``search()`` and
            ``response``.
        simple_hostname: Host name without domain part, matched against
            the ``cn`` attribute.
        ldap_ou: Search base in which the computer object is looked up.

    Returns:
        True if the first entry of the search response carries a non-empty
        ``cn`` attribute, False otherwise (including an empty response).
    """
    connection.search(ldap_ou, '(cn=%s)' % simple_hostname,
                      attributes=['cn'])
    try:
        common_name = connection.response[0]['attributes']['cn']
    except (IndexError, KeyError, TypeError):
        # IndexError: the search returned no entries at all.
        # KeyError:   the first entry has no attributes / no cn.
        # TypeError:  no response is available (response is None).
        common_name = ''

    return bool(common_name)


def get_computer_dn(connection, simple_hostname, ldap_ou):
    """Take the simple host name and return its distinguished name.

    Args:
        connection: Bound LDAP connection offering ``search()`` and
            ``response``.
        simple_hostname: Host name without domain part, matched against
            the ``cn`` attribute.
        ldap_ou: Search base in which the computer object is looked up.

    Returns:
        The ``distinguishedName`` attribute of the first entry in the search
        response, or an empty string when no such entry/attribute exists.
    """
    connection.search(ldap_ou, '(cn=%s)' % simple_hostname,
                      attributes=['distinguishedName'])
    try:
        return connection.response[0]['attributes']['distinguishedName']
    except (IndexError, KeyError, TypeError):
        # IndexError: the search returned no entries at all.
        # KeyError:   the first entry has no attributes / no DN attribute.
        # TypeError:  no response is available (response is None).
        return ''


def get_computer_groups(connection, search_base, computer_dn, group_prefix):
    """Return the names of the groups the computer object is a member of.

    member:1.2.840.113556.1.4.1941:=%s is a special Active Directory OID that
    returns nested groups and not just the first level. The result is
    restricted to ``search_base`` and to groups whose name starts with
    ``group_prefix``.

    Args:
        connection: Bound LDAP connection offering ``search()`` and
            ``response``.
        search_base: Search base (Organizational Unit) holding the groups.
        computer_dn: Distinguished name of the computer object.
        group_prefix: Prefix a group name must start with to be returned.

    Returns:
        A list of matching group names (``cn`` values) in response order;
        empty if nothing matched.
    """
    connection.search(search_base,
                      '(member:1.2.840.113556.1.4.1941:=%s)' % computer_dn,
                      attributes=['cn'])

    groups = []
    for entry in connection.response or []:
        # Response items without attributes (such as search result
        # references) are not groups. Skip them individually rather than
        # blindly dropping the last item, which could be a real group.
        attributes = entry.get('attributes')
        if not attributes:
            continue
        group_name = attributes.get('cn')
        # Only keep groups whose name starts with the prefix
        if group_name and group_name.startswith(group_prefix):
            groups.append(group_name)
    return groups


def format_output(output):
    """Serialize the inventory data to a JSON string for Ansible.

    Keys are sorted and the document is indented by four spaces.

    The Ansible requirements are documented here:
    http://docs.ansible.com/ansible/latest/dev_guide/developing_inventory.html
    """
    json_options = {
        'sort_keys': True,
        'indent': 4,
        'separators': (',', ': '),
    }
    return dumps(output, **json_options)


def write_cache(cache_file, output):
    """Write inventory cache to file.

    Args:
        cache_file: Path of the cache file; it is overwritten.
        output: Inventory data, serialized with ``format_output``.

    Raises:
        IOError: Re-raised after logging when the cache file cannot be
            opened or written.
    """
    # Serialize before opening the file so that a serialization failure does
    # not leave a truncated cache behind.
    serialized = format_output(output)
    try:
        with open(cache_file, 'w') as cache:
            # Write the whole document at once instead of char by char.
            cache.write(serialized)
    except IOError:
        logger.console_logger.error('Error while writing cache.')
        logger.console_logger.error(
            'Make sure base process has the right permissions and path exists for %s' % cache_file)
        raise


def read_cache(cache_file):
    """Return the text stored in the cache file.

    Returns False when ``cache_file`` is not an existing regular file.
    """
    if path.isfile(cache_file):
        with open(cache_file, 'r') as handle:
            content = handle.read()
        return content
    return False


def main():
    """Fetch the host's groups from AD and print them in Ansible JSON.

    Online (the AD bind succeeded): the common group plus every matching AD
    group of this host is written to the cache file and printed, with the
    AD-bound group added to the printed output only.

    Offline (no AD connection): the cache file content is printed if there
    is one, otherwise only the configured common group is printed.

    Exits through ``sys.exit`` if the cache file cannot be written.
    """
    cfg = configer.read_config('ad')
    cache_file = cfg['cache_file']

    # Host name as reported by the OS, lower-cased, and its short form
    fqdn = gethostname().lower()
    simple_host_name = get_simple_host_name(fqdn)

    # Connecting to Active Directory and checking connection status
    ad_connection = connect_to_ad(
        cfg['ldap_user'], cfg['ldap_pw'], cfg['ldap_host'])

    # We are not bound to AD, we are offline
    if not ad_connection:
        cached_output = read_cache(cache_file)
        if cached_output:
            # Printing cached inventory
            print(cached_output)
        else:
            # Printing default group; use the configured name so that the
            # offline result matches the online one.
            print(format_output({cfg['common_group']: [fqdn]}))
        return

    # Initializing output with the group every host belongs to
    output = {cfg['common_group']: [fqdn]}

    # Checking if host is in AD
    if host_exist_in_ad(ad_connection, simple_host_name,
                        cfg['ldap_ou_computers']):
        # Looking up computers distinguished name
        computer_dn = get_computer_dn(
            ad_connection, simple_host_name, cfg['ldap_ou_computers'])

        # Looking up computers groups
        computer_groups = get_computer_groups(ad_connection,
                                              cfg['ldap_ou_groups'],
                                              computer_dn,
                                              cfg['group_prefix'])
        # Adding groups to output
        for group in computer_groups:
            output[group] = [fqdn]

    # Writing output to cache file
    try:
        write_cache(cache_file, output)
    except IOError as e:
        sys.exit(e)

    # Adding online group after cache is written.
    # We do not want to cache this group
    output[cfg['common_ad_bound_group']] = [fqdn]

    # Printing output in Ansible JSON syntax. print(...) with a single
    # argument behaves the same under Python 2 and Python 3.
    print(format_output(output))


# Run the inventory lookup only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
