#!python

import a2conf
import argparse
import logging
import requests
import os
import socket
import random
import string
import sys
import subprocess
import pathlib

from a2conf import Node, VhostNotFound

# Module-wide logger placeholder: main() rebinds it (via `global log`) to the
# 'diag' logger, so helpers that call log.* only work after main() has started.
log = None



class FatalError(Exception):
    """Signal that checking of the current certificate config must stop.

    process_file() raises it after recording a problem in the report and
    catches it itself, so the report is still printed.
    """

class LetsEncryptCertificateConfig:
    """In-memory view of a LetsEncrypt renewal config.

    The data lives in ``self.content``: a dict mapping the raw section header
    line (for example ``'[[webroot_map]]'``; ``''`` for keys that appear before
    any header) to a dict of ``key -> value`` strings.

    The object is built either from a file on disk or, when ``webroot`` is
    given, synthesised in memory from a webroot and a list of domains.
    """

    def __init__(self, path, webroot=None, domains=None):
        """
        :param path: renewal config file to read (ignored when webroot is set)
        :param webroot: if truthy, build an in-memory config instead of reading
        :param domains: iterable of domain names mapped to ``webroot``
        """
        self.content = dict()
        if webroot:
            self.init_webroot(webroot, domains)
        else:
            self.init_readfile(path)

    def init_webroot(self, webroot, domains):
        """Build a config that maps every domain in ``domains`` to ``webroot``."""
        self.path = '::internal::'
        # `domains or ()` tolerates domains=None (the constructor default).
        self.content = {
            '[[webroot_map]]': {domain: webroot for domain in domains or ()}
        }

    def init_readfile(self, path):
        """Parse the renewal config file at ``path`` into ``self.content``.

        Blank lines and lines starting with ``#`` are skipped. A line starting
        with ``[`` opens a new section. Every other line is split on the FIRST
        ``=`` only, so values may themselves contain ``=``; a line without any
        ``=`` is stored as a key with an empty value.
        """
        self.path = path
        self.content = dict()
        self.content[''] = dict()
        section = ''

        with open(path) as f:
            for line in f:
                line = line.strip()
                if not line or line.startswith('#'):
                    # skip empty lines and comments
                    continue

                if line.startswith('['):
                    # new section
                    section = line
                    self.content[section] = dict()
                    continue

                # partition() never raises, unlike unpacking split('=')
                key, _, value = line.partition('=')
                self.content[section][key.strip()] = value.strip()

    @property
    def domains(self):
        """Domain names listed in the ``[[webroot_map]]`` section.

        :raises KeyError: if the config has no ``[[webroot_map]]`` section
            (a message is printed before re-raising)
        """
        try:
            return self.content['[[webroot_map]]'].keys()
        except KeyError:
            print("No [[webroot_map]] in {}".format(self.path))
            raise

    def get_droot(self, domain):
        """Return the webroot recorded for ``domain`` (KeyError if unknown)."""
        return self.content['[[webroot_map]]'][domain]

    def dump(self):
        """Print the parsed content (debug helper)."""
        print(self.content)


class Report:
    """Collects informational and problem messages and prints them as a block.

    Info messages come in two flavours: plain ones, and ones attached to an
    object (e.g. a hostname); the latter are grouped per message text so that
    all objects sharing a message are printed on a single line.
    """

    def __init__(self, name):
        self.name = name
        self._info = []          # plain info messages, in insertion order
        self._problem = []       # problem messages, in insertion order
        self.prefix = ' ' * 4    # indentation used for every printed message
        self.objects = {}        # message text -> list of objects it applies to

    def info(self, msg, object=None):
        """Record an info message, grouped under ``object`` when one is given."""
        if not object:
            self._info.append(msg)
            return
        self.objects.setdefault(msg, []).append(object)

    def problem(self, msg):
        """Record a problem message."""
        self._problem.append(msg)

    def has_problems(self):
        """Return True if at least one problem was recorded."""
        return len(self._problem) > 0

    def report(self):
        """Print the header, then the info and problem sections, then a separator."""
        header = "=== {} PROBLEM ===" if self._problem else "=== {} ==="
        print(header.format(self.name))

        if self._info or self.objects:
            print("Info:")
            # grouped messages first, then the plain ones
            for text, owners in self.objects.items():
                print("{}({}) {}".format(self.prefix, ', '.join(owners), text))
            for text in self._info:
                print(self.prefix + text)

        if self._problem:
            print("Problems:")
            for text in self._problem:
                print(self.prefix + text)

        print("---\n")


def detect_ip():
    """Ask http://ifconfig.me/ for this machine's public IP address.

    :return: ``['127.0.0.1', <public ip>]``
    :raises requests.RequestException: if the HTTP request itself fails
    :raises RuntimeError: if the service answers with a non-200 status
    """
    url = 'http://ifconfig.me/'
    try:
        # A timeout keeps the tool from hanging forever on a dead network.
        r = requests.get(url, timeout=10)
    except requests.RequestException as e:
        log.error('Failed to get IP from {} ({}), use --ip a.b.c.d'.format(url, e))
        raise

    if r.status_code != 200:
        message = 'Failed to get IP from {} ({}), use --ip a.b.c.d'.format(url, r.status_code)
        log.error(message)
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        raise RuntimeError(message)

    # Strip whitespace so the value compares equal to resolved addresses.
    return ['127.0.0.1', r.text.strip()]


def resolve(name):
    """
    Look up the addresses of a hostname via socket.gethostbyname_ex().

    :param name: hostname to resolve
    :return: list of addresses; an empty list (after logging a warning) when
             the lookup fails with socket.gaierror
    """
    try:
        _hostname, _aliases, addresses = socket.gethostbyname_ex(name)
    except socket.gaierror:
        log.warning("WARNING: Cannot resolve {}".format(name))
        return []
    return addresses

def simulate_check(servername, droot, report):
    """Simulate an HTTP-01 style check for ``servername``.

    A file with random content is written to
    ``<droot>/.well-known/acme-challenge/`` and then fetched over plain HTTP
    from ``servername``; the check passes when the body matches.

    :param servername: host name used to build the test URL
    :param droot: directory the test file is created under
    :param report: Report that receives the info/problem messages
    :return: True if the fetched content matched, False otherwise
    """
    alphabet = string.ascii_uppercase + string.ascii_lowercase + string.digits
    test_data = ''.join(random.choice(alphabet) for _ in range(100))
    test_basename = 'certbot_diag_' + ''.join(random.choice(alphabet) for _ in range(10))
    test_dir = os.path.join(droot, '.well-known', 'acme-challenge')
    test_file = os.path.join(test_dir, test_basename)
    test_url = 'http://' + servername + '/.well-known/acme-challenge/' + test_basename

    try:
        log.debug('create test file ' + test_file)
        try:
            os.makedirs(test_dir, exist_ok=True)
            with open(test_file, "w") as f:
                f.write(test_data)
        except OSError as e:
            # e.g. read-only or non-writable DocumentRoot: report it, don't crash
            report.problem("Cannot create test file {}: {}".format(test_file, e))
            return False

        log.debug('test URL ' + test_url)
        try:
            # Timeout so an unresponsive site cannot hang the whole run.
            r = requests.get(test_url, allow_redirects=True, timeout=30)
        except requests.RequestException as e:
            report.problem("URL {} got exception: {}".format(test_url, e))
            return False

        if r.status_code != 200:
            report.problem('URL {} got status code {}. Used DocumentRoot {}. Maybe Alias or RewriteRule working?'.format(
                test_url, r.status_code, droot))
            return False

        if r.text != test_data:
            report.problem("Simulated check fails root: {} url: {}".format(droot, test_url))
            return False

        report.info("Simulated check match root: {}".format(droot), object=servername)
        return True
    finally:
        # Always remove the probe file, whatever happened above.
        try:
            os.unlink(test_file)
        except FileNotFoundError:
            pass


def is_local_ip(hostname, local_ip_list, report):
    """Record in ``report`` whether each address of ``hostname`` is local.

    Every address returned by resolve() yields either an info entry (address
    is in ``local_ip_list``) or a problem entry (it is not). Nothing is
    returned; a hostname with no addresses adds nothing to the report.
    """
    for address in resolve(hostname):
        if address not in local_ip_list:
            report.problem('{} ({}) not local {}'.format(hostname, address, local_ip_list))
            continue
        report.info('is local {}'.format(address), object=hostname)

def get_all_hostnames(hostname, apacheconf=None, root=None):
    """Return ServerName plus all ServerAlias names of the vhost serving ``hostname``.

    :param hostname: name used to locate the VirtualHost (via yield_vhost)
    :param apacheconf: apache config file path, used when ``root`` is None
    :param root: already-parsed config tree, if available
    :return: list of names, ServerName first
    :raises StopIteration: if no matching VirtualHost (or no ServerName) exists
    """
    vhost = next(yield_vhost(hostname, apacheconf, root=root))
    names = [next(vhost.children('ServerName')).args]
    for alias in vhost.children('ServerAlias'):
        # split() without argument: repeated spaces or tabs between aliases
        # must not produce empty host names.
        names.extend(alias.args.split())
    return names


def get_webroot(hostname, apacheconf, root=None):
    """Return the DocumentRoot argument of the first vhost matching ``hostname``.

    Raises StopIteration when no vhost matches or the vhost has no DocumentRoot.
    """
    matching_vhosts = yield_vhost(hostname, apacheconf, root=root)
    first_vhost = next(matching_vhosts)
    docroot_node = next(first_vhost.children('DocumentRoot'))
    return docroot_node.args


def _listens_on_port_80(vhost_args):
    """Return True if any address in the VirtualHost arguments uses port 80.

    Addresses are compared token by token (``*:80``, ``1.2.3.4:80``,
    ``[::1]:80``) so that ports such as 8080 or 8000 do not match.
    """
    # rstrip('>') is defensive in case the parser leaves the closing bracket
    # of the section header in the arguments -- NOTE(review): confirm.
    return any(token.rstrip('>').endswith(':80') for token in vhost_args.split())


def yield_vhost(domain, apacheconf=None, root=None):
    """Yield every port-80 VirtualHost whose ServerName or ServerAlias is ``domain``.

    :param domain: host name to look for (compared case-insensitively)
    :param apacheconf: apache config file to parse when ``root`` is None
    :param root: already-parsed config tree; avoids re-reading the file
    :return: generator of vhost nodes; each matching vhost is yielded once,
             even if the domain appears in it several times

    VirtualHosts without a ServerName are skipped.
    """
    if root is None:
        root = a2conf.Node()
        root.read_file(apacheconf)

    wanted = domain.lower()

    for vhost in root.children('<VirtualHost>'):
        if not _listens_on_port_80(vhost.args):
            continue

        try:
            servername = next(vhost.children('servername')).args
        except StopIteration:
            # no ServerName: cannot be matched by name
            continue

        names = [servername]
        for alias in vhost.children('serveralias'):
            names.extend(alias.args.split())

        if wanted in (name.lower() for name in names):
            yield vhost

def _redirect_covers_acme(redirect_args):
    """Return True if a Redirect with these arguments catches ACME challenge URLs.

    Handles both ``Redirect URL-path URL`` and ``Redirect status URL-path URL``;
    a redirect is relevant when its URL-path is a path prefix of
    ``/.well-known/acme-challenge/``.
    """
    tokens = redirect_args.split()
    # Drop the optional leading status (keyword or numeric code).
    if tokens and (tokens[0].lower() in ('permanent', 'temp', 'seeother', 'gone')
                   or tokens[0].isdigit()):
        tokens = tokens[1:]
    if not tokens:
        return False

    # Normalise to '/segment/.../' so the prefix test respects path segments.
    prefix = '/' + tokens[0].strip('/')
    if not prefix.endswith('/'):
        prefix += '/'
    return '/.well-known/acme-challenge/'.startswith(prefix)


def process_file(leconf_path, local_ip_list, args, leconf=None):
    """Run all diagnostics for one LetsEncrypt certificate config and print a report.

    For every domain of the config this checks that the name resolves to a
    local address, that exactly one port-80 VirtualHost serves it, that the
    DocumentRoot exists, that no Redirect catches the challenge path, that the
    LetsEncrypt webroot matches Apache's DocumentRoot (or ``args.altroot``) and
    finally runs simulate_check(). The first fatal problem stops the checks for
    this config; the report is printed afterwards if it contains problems or
    ``args.quiet`` is not set.

    :param leconf_path: path of the renewal config file (None when ``leconf`` is given)
    :param local_ip_list: addresses considered local
    :param args: parsed command line (uses host, apacheconf, altroot, quiet)
    :param leconf: ready-made LetsEncryptCertificateConfig; skips reading the file
    :return: None
    """
    report = Report(leconf_path or 'manual')

    try:
        if leconf is None:
            report.info("LetsEncrypt conf file: " + leconf_path)
            if not os.path.exists(leconf_path):
                report.problem("Missing LetsEncrypt conf file " + leconf_path)
                raise FatalError
            lc = LetsEncryptCertificateConfig(path=leconf_path)
        else:
            lc = leconf

        if args.host and args.host not in lc.domains:
            # Config does not cover the requested host: skip silently (no report).
            log.debug('Skip file {}: not found domain {}'.format(leconf_path, args.host))
            return

        for domain in lc.domains:
            log.debug("check domain {} from {}".format(domain, leconf_path or 'manual'))
            le_droot = lc.get_droot(domain)

            # Local IP check
            is_local_ip(domain, local_ip_list, report)

            #
            # Exactly one VirtualHost must serve the domain
            #
            vhost_list = list(yield_vhost(domain, args.apacheconf))

            if not vhost_list:
                report.problem('Not found domain {} in {}'.format(domain, args.apacheconf))
                raise FatalError

            if len(vhost_list) > 1:
                report.problem('Found {} virtualhost for {} in {}'.format(len(vhost_list), domain, args.apacheconf))
                raise FatalError

            vhost = vhost_list[0]

            report.info('Vhost: {}:{}'.format(vhost.path, vhost.line), object=domain)

            #
            # DocumentRoot exists?
            #
            try:
                droot = next(vhost.children('DocumentRoot')).args
            except StopIteration:
                report.problem("No DocumentRoot in vhost at {}:{}".format(vhost.path, vhost.line))
                raise FatalError

            if droot is not None and os.path.isdir(droot):
                report.info("DocumentRoot: {}".format(droot), object=domain)
            else:
                report.problem("DocumentRoot dir not exists: {} (problem!)".format(droot))

            #
            # Redirect check: look at every Redirect, in either argument form
            #
            for r in vhost.children('Redirect'):
                if _redirect_covers_acme(r.args):
                    report.problem('Requests will be redirected: {} {}'.format(r, r.args))

            #
            # DocumentRoot matches?
            #
            if not args.altroot:
                # No altroot, simple check
                if os.path.realpath(le_droot) == os.path.realpath(droot):
                    report.info('DocumentRoot {} matches LetsEncrypt and Apache'.format(droot), object=domain)
                else:
                    report.problem(
                        'DocRoot mismatch for {}. Apache: {} LetsEncrypt: {}'.format(domain, droot, le_droot))
                # simulate anyway
                simulate_check(domain.lower(), droot, report)
            else:
                # AltRoot
                if os.path.realpath(le_droot) == os.path.realpath(args.altroot):
                    report.info('Domain name {} le root {} matches --altroot'.format(domain, le_droot))
                    simulate_check(domain.lower(), le_droot, report)
                elif os.path.realpath(le_droot) == os.path.realpath(droot):
                    report.info('Domain name {} le root {} matches DocumentRoot'.format(domain, le_droot))
                    simulate_check(domain.lower(), droot, report)
                else:
                    report.problem(
                        'DocRoot mismatch for {}. AltRoot: {} LetsEncrypt: {} Apache: {}'.format(
                            domain, args.altroot, le_droot, droot))

            log.debug("END OF ITER for {}".format(domain))

    except FatalError:
        # The problem is already recorded in the report; just stop checking.
        pass

    #
    # Final report
    #
    if report.has_problems() or not args.quiet:
        report.report()

    return

def get_aliases(names, apacheconf):
    """
    Merge the given names with every hostname of the vhost serving names[0].

    :param names: names to include as-is; the first one selects the vhost
    :param apacheconf: apache config file name
    :return: set of all names
    """
    merged = list(names)
    vhost_names = get_all_hostnames(names[0], apacheconf)
    return set(merged + vhost_names)


def place(apacheconfig, token, validation, domain):
    """Write ``validation`` into the challenge file ``token`` of ``domain``.

    The file is created as
    ``<DocumentRoot>/.well-known/acme-challenge/<token>`` for the vhost found
    for ``domain``; missing directories are created. Prints
    ``"<domain> <token>"`` on success, or a notice if no vhost is found.
    """
    config = Node(apacheconfig)
    try:
        docroot = config.find_vhost(domain).first('DocumentRoot').args
        challenge_dir = os.path.join(docroot, '.well-known', 'acme-challenge')
        challenge_file = os.path.join(challenge_dir, token)

        # make sure the challenge directory exists before writing into it
        os.makedirs(challenge_dir, exist_ok=True)

        with open(challenge_file, "w") as out:
            out.write(validation)
        print(f"{domain} {token}")
    except VhostNotFound:
        print(f"Not found vhost for {domain}")

def cleanup(apacheconfig, token, domain):
    """Delete the challenge file ``token`` from the DocumentRoot of ``domain``.

    Does nothing if the file does not exist; prints a notice if no vhost is
    found for ``domain``.
    """
    config = Node(apacheconfig)
    try:
        docroot = config.find_vhost(domain).first('DocumentRoot').args
        challenge_file = os.path.join(docroot, '.well-known', 'acme-challenge', token)
        if not os.path.isfile(challenge_file):
            return
        print("Remove token", challenge_file)
        os.remove(challenge_file)
    except VhostNotFound:
        print(f"Not found vhost for {domain}")


def main():
    """Command line entry point: parse arguments and run the selected command.

    Commands (first match wins): ``--prepare``, ``--create``, ``--place``,
    ``--cleanup``, ``--process``. Usage errors and missing vhosts/paths exit
    with status 1; ``--create`` exits with certbot's return code.
    """
    global log

    def_apacheconf = '/etc/apache2/apache2.conf'
    def_lepath = '/etc/letsencrypt/renewal/'

    epilog = 'Examples:\n' \
             '# Verify all LetsEncrypt config\n' \
             '{me} --process\n\n'.format(me=sys.argv[0])

    epilog += "# Verify one LetsEncrypt certificate:\n" \
              "{me} --process --host example.com\n\n".format(me=sys.argv[0])

    epilog += "# Verify if cert could be requested (preparation). Existing certificate not needed (all manual):\n" \
              "{me} --prepare -d example.com -d www.example.com -w /var/www/virtual/example.com\n" \
              "# Or much more simpler (aliases and webroot will be guessed from apache config):\n" \
              "{me} --prepare -d example.com --aliases\n\n".format(me=sys.argv[0])

    epilog += "# Create certificate for example.com and all of it's aliases" \
        "(www.example.com, example.net, www.example.net)\n" \
        "{me} --create -d example.com --aliases\n" \
        "{me} --create -d example.com -d www.example.com -d example.net -d www.example.net\n" \
        "\n".format(me=sys.argv[0])

    parser = argparse.ArgumentParser(description='Apache2 / Certbot misconfiguration diagnostic', epilog=epilog,
                                     formatter_class=argparse.RawTextHelpFormatter)

    g = parser.add_argument_group('Check existent LetsEncrypt verification (if "certbot renew" fails)')
    g.add_argument('--host', default=None, metavar='HOST',
                   help='Process only letsencrypt config file for HOST. def: {}'.format(None))
    g.add_argument('--process', default=False, action='store_true', help='Process all letsencrypt certificates')
    g.add_argument('--lepath', default=def_lepath, nargs='?', dest='lepath', metavar='LETSENCRYPT_DIR_PATH',
                   help='Lets Encrypt directory def: {}'.format(def_lepath))

    g = parser.add_argument_group('Check non-existent LetsEncrypt verification (if "certbot certonly --webroot" fails)')
    g.add_argument('--prepare', default=False, action='store_true',
                   help='Preparation check (before requesting LetsEncrypt cert). '
                        'You may also use --aliases option')
    g.add_argument('-w', '--webroot', help='DocumentRoot for new website')
    # action='append' so that repeated "-d a -d b" accumulates instead of the
    # last occurrence silently replacing the earlier ones; flattened below.
    g.add_argument('-d', '--domain', nargs='*', action='append', metavar='DOMAIN',
                   help='hostname/domain for new website')

    g = parser.add_argument_group('General options')
    g.add_argument('--altroot', default=None, metavar='DocumentRoot',
                   help='Try also other root (in case if Alias used). def: {}'.format(None))
    g.add_argument('-c', '--conf', dest='apacheconf', nargs='?', default=def_apacheconf, metavar='PATH',
                   help='Config file path (def: {})'.format(def_apacheconf))
    g.add_argument('-v', '--verbose', action='store_true',
                   default=False, help='verbose mode')
    g.add_argument('-q', '--quiet', action='store_true',
                   default=False, help='quiet mode, suppress output for sites without problems')
    g.add_argument('-i', '--ip', nargs='*',
                   help='Default addresses. Autodetect if not specified')
    g.add_argument(metavar='CERTBOT_OPTS', nargs='*', dest='opts',
                   help='options to pass to certbot, e.g. --test-cert')

    g = parser.add_argument_group('Generate new certificate (certbot certonly --webroot)')
    g.add_argument('--create', default=False, action='store_true',
                   help='Create LetsEncrypt certificate (via certbot). Use also -d and --aliases')
    g.add_argument('--aliases', action='store_true',
                   default=False, help='Include all ServerName and ServerAlias found in VirtualHost')
    g.add_argument('--root', default=None, metavar='DocumentRoot', dest='docroot', help='DocumentRoot for HTTP site (if --create both)')

    g = parser.add_argument_group('Assisting in  (certbot certonly --manual)')
    g.add_argument('--place', nargs=2, metavar=('CERTBOT_TOKEN', 'CERTBOT_VALIDATION'),
                   help='place validation value into token file')
    g.add_argument('--cleanup', metavar='CERTBOT_TOKEN', help='cleanup validation token')

    args = parser.parse_args()

    # "-d a b -d c" is parsed as [['a', 'b'], ['c']]: flatten to one list.
    if args.domain is not None:
        args.domain = [name for group in args.domain for name in group]

    logging.basicConfig(
        format='%(message)s',
        level=logging.INFO)

    log = logging.getLogger('diag')

    if args.verbose:
        log.setLevel(logging.DEBUG)
        log.debug('Verbose mode')

    def local_ips():
        """Return --ip values, or autodetect; called only by commands that need it."""
        if args.ip:
            ip_list = args.ip
        else:
            log.debug("Autodetect IP")
            ip_list = detect_ip()
        log.debug("my IP list: {}".format(ip_list))
        return ip_list

    if args.prepare:
        if not args.domain:
            print("--prepare requires at least one -d and optional --aliases")
            sys.exit(1)
        try:
            aliases = get_aliases(args.domain, args.apacheconf) if args.aliases else args.domain
            webroot = args.webroot or get_webroot(args.domain[0], args.apacheconf)
        except StopIteration:
            # no VirtualHost / ServerName / DocumentRoot found for the domain
            print("Cannot find webroot for {} (starting from {}). Maybe site is not yet created/enabled?".format(
                args.domain[0], args.apacheconf))
            sys.exit(1)
        lc = LetsEncryptCertificateConfig(path=None, webroot=webroot, domains=aliases)
        process_file(leconf_path=None, local_ip_list=local_ips(), args=args, leconf=lc)

    elif args.create:
        if not args.domain:
            print("--create requires at least one -d and optional --aliases")
            sys.exit(1)

        name = args.domain[0]

        print("Create cert for {}".format(name))

        try:
            webroot = get_webroot(name, args.apacheconf)
            if args.aliases:
                # every known name except the main one; sorted so the certbot
                # command line is deterministic
                extra_names = sorted(get_aliases(args.domain, args.apacheconf) - {name})
            else:
                extra_names = args.domain[1:]
        except StopIteration:
            print("Cannot find webroot for {} (starting from {}). Maybe site is not yet created/enabled?".format(
                name, args.apacheconf))
            # failure must not look like success to calling scripts
            sys.exit(1)

        cmd = ['certbot', 'certonly', '--webroot', '-w', webroot, '-d', name]
        for extra in extra_names:
            cmd.extend(['-d', extra])

        if args.opts:
            cmd.extend(args.opts)

        print("RUNNING: {}".format(' '.join(cmd)))
        cp = subprocess.run(cmd)
        sys.exit(cp.returncode)

    elif args.place:
        if not args.domain:
            print("--place requires -d DOMAIN")
            sys.exit(1)
        place(args.apacheconf, args.place[0], args.place[1], args.domain[0])

    elif args.cleanup:
        if not args.domain:
            print("--cleanup requires -d DOMAIN")
            sys.exit(1)
        cleanup(args.apacheconf, args.cleanup, args.domain[0])

    elif args.process:
        if args.lepath and os.path.isdir(args.lepath):
            ip_list = local_ips()
            for f in sorted(os.listdir(args.lepath)):
                path = os.path.join(args.lepath, f)
                if not path.endswith('.conf'):
                    continue
                if not (os.path.isfile(path) or os.path.islink(path)):
                    continue
                process_file(leconf_path=path, local_ip_list=ip_list, args=args)
        elif args.lepath and os.path.isfile(args.lepath):
            process_file(leconf_path=args.lepath, local_ip_list=local_ips(), args=args)
        else:
            print("LetsEncrypt path {} is neither a directory nor a file".format(args.lepath))
            sys.exit(1)

    else:
        print("Need command")


# Run the CLI only when executed as a script, not when the module is imported.
if __name__ == '__main__':
    main()
