#!python

from a2conf import Node
import argparse
import sys
import subprocess
import os
import socket
import requests

class MyException(Exception):
    """Base class for the errors this tool raises itself.

    main() catches this type and reports it as ``Error: <message>``.
    """

class VhostNotFound(MyException):
    """No <VirtualHost> matched the requested host name / vhost spec."""

class ArgumentError(MyException):
    """A required command-line argument (e.g. the domain list) is missing."""

#
# subroutines
#
def get_all_hostnames(vhost):
    """
    Return every host name a vhost answers to.

    The result starts with the argument of the first ServerName directive
    (when there is one), followed by each name listed in every ServerAlias
    directive, in file order.

    vhost -- node exposing children(name), which yields directive nodes
             that carry their arguments as a string in .args
    """
    names = list()
    try:
        servername = next(vhost.children('ServerName')).args
        names.append(servername)
    except StopIteration:
        # ServerName is optional
        pass

    for alias in vhost.children('ServerAlias'):
        # split() without a separator treats any run of spaces/tabs as one
        # delimiter, so "a.com  b.com" never produces an empty host name
        names.extend(alias.args.split())
    return names

def get_vhost_by_host(root, host, arg=None):
    """
    Return the first <VirtualHost> under root that serves host.

    root -- parsed config node to search
    host -- host name that must be among the vhost's ServerName/ServerAlias
    arg  -- optional substring that must occur in the <VirtualHost ...>
            arguments (e.g. ':80'); vhosts without it are skipped

    Raises VhostNotFound when nothing matches.
    """
    candidates = root.children('<VirtualHost>')
    if arg:
        # lazily drop vhosts whose address spec does not contain arg
        candidates = (vh for vh in candidates if arg in vh.args)

    for vh in candidates:
        if host in get_all_hostnames(vh):
            return vh

    raise VhostNotFound('Vhost args: {} host: {} not found'.format(arg, host))
    
def get_vhconfig(apacheconfig, vhspec, domainlist):
    """
    Return the path of the config file that holds the vhost for a domain.

    apacheconfig -- main apache config file to start the search from
    vhspec       -- optional <VirtualHost> argument filter (e.g. '*:80')
    domainlist   -- list of domains; only the first one is looked up

    Raises ArgumentError when domainlist is empty and VhostNotFound when
    no vhost serves the first domain.
    """
    if not domainlist:
        raise ArgumentError("Domain (-d) is required")

    main_root = Node(apacheconfig)
    vhost = get_vhost_by_host(main_root, domainlist[0], vhspec)

    # open the file that contains the vhost and report the path of that node
    vhost_file = Node(vhost.path)
    return vhost_file.path

def guess_webroot(basename):
    """
    Suggest a DocumentRoot path for a new site.

    Returns basename joined onto the first well-known web directory that
    exists on this machine. The returned directory itself is not created.

    Raises StopIteration when none of the well-known directories exist.
    """
    candidates = (
        '/var/www/virtual',
        '/var/www',
        '/usr/local/apache/htdocs',
        '/home/www/htdocs',
    )
    parent = next(filter(os.path.isdir, candidates), None)
    if parent is None:
        raise StopIteration
    return os.path.join(parent, basename)

def guess_apacheconfig(hostname):
    """Return the conventional sites-available config path for hostname."""
    conf_name = hostname + '.conf'
    return os.path.join('/etc/apache2/sites-available/', conf_name)
    

#
# methods
#

def list_vhosts(apacheconf):
    """
    Print one line per <VirtualHost> found from apacheconf.

    Each line shows file:line, the <VirtualHost> arguments, the DocumentRoot
    (None when the vhost has none) and all host names separated by spaces.
    """
    for vh in Node(apacheconf).children('<VirtualHost>'):
        names = ' '.join(get_all_hostnames(vh))

        try:
            docroot = vh.first('DocumentRoot').args
        except AttributeError:
            # no DocumentRoot directive in this vhost
            docroot = None

        location = vh.path + ':' + str(vh.line)
        print(location, vh.args, docroot, names)

def make_both(apacheconfig, config,  domainlist, webroot, auto):
    """
    Create a complete HTTP + HTTPS site for the given domains.

    Steps: write a basic HTTP vhost (make_basic), enable it with a2ensite,
    request a certificate with certbot in webroot mode, convert the vhost to
    HTTPS (make_convert) and add an HTTP->HTTPS redirect vhost
    (make_redirect). Apache is reloaded after each config change.

    apacheconfig -- main apache config file
    config       -- vhost config file to create/extend; guessed from the
                    first domain when empty and auto is set
    domainlist   -- domains for the site; the first one is the ServerName
    webroot      -- DocumentRoot; may be empty when auto is set
    auto         -- guess missing parameters instead of failing

    Exits the process with status 1 when no domain is given or when certbot
    fails.
    """

    def reload():
        print(".. reload")
        # subprocess.run(['apache2', '-k', 'graceful'])
        subprocess.run(['systemctl', 'reload', 'apache2'])

    # Validate before domainlist[0] is used below; same message and exit
    # status as make_basic so behaviour does not depend on --auto.
    if not domainlist:
        print("Specify at least one domain name: -d example.com")
        sys.exit(1)

    if auto and not config:
        # generate config file name
        config = guess_apacheconfig(domainlist[0])
        print('# [auto] will use VirtualHost config file {}'.format(config))

    print(".. basic")
    make_basic(apacheconfig, config,  domainlist, webroot, auto)

    print(".. enable")
    subprocess.run(['a2ensite', os.path.basename(config)])
    reload()

    # re-read webroot from the config in case it was auto-generated
    root = Node(apacheconfig)
    vhost = get_vhost_by_host(root, domainlist[0], ':80')
    webroot = vhost.first('DocumentRoot').args

    print(".. request certificate")
    cmd = ['certbot', 'certonly', '--webroot', '-w', webroot]
    for domain in domainlist:
        cmd.extend(['-d', domain])

    cp = subprocess.run(cmd)
    # An explicit check instead of assert: asserts vanish under "python -O"
    # and the conversion below must not run without a certificate.
    if cp.returncode != 0:
        print("certbot failed with exit code {}".format(cp.returncode))
        sys.exit(1)

    print(".. convert")
    make_convert(apacheconfig, domainlist)
    reload()

    print(".. redirect")
    make_redirect(apacheconfig, domainlist)
    reload()

def make_basic(apacheconfig, config,  domainlist, webroot, auto):
    """
    Create a basic plain-HTTP (*:80) VirtualHost for the given domains.

    apacheconfig -- main apache config, searched for conflicting vhosts
    config       -- file to write the vhost to (appended to if it exists);
                    guessed from the first domain when empty and auto is set
    domainlist   -- first entry becomes ServerName, the rest ServerAlias
    webroot      -- DocumentRoot; guessed (and created if missing) when
                    empty and auto is set
    auto         -- guess missing parameters instead of failing

    Every validation failure prints a message and exits with status 1.
    """
    # sanity check
    if not domainlist:
        print("Specify at least one domain name: -d example.com")
        sys.exit(1)
    if not config:
        if auto:
            # generate config file name
            config = guess_apacheconfig(domainlist[0])
            print('# [auto] will use VirtualHost config file {}'.format(config))
        else:
            print("Specify VirtualHost config file, e.g. -c /etc/apache2/sites-available/{}.conf or use --auto".format(domainlist[0]))
            sys.exit(1)

    if not webroot:
        # guess_webroot signals "nothing suitable" with StopIteration
        try:
            guessed = guess_webroot(domainlist[0])
        except StopIteration:
            guessed = None

        if auto:
            if guessed is None:
                print("Could not autodetect webroot, use -w <WEBROOT>")
                # Exit like every other validation failure here; a plain
                # return would let make_both continue without a vhost.
                sys.exit(1)
            webroot = guessed
            print('# [auto] will use webroot {}'.format(webroot))
            if not os.path.isdir(webroot):
                print('# [auto] create webroot {}'.format(webroot))
                os.mkdir(webroot)
        else:
            if guessed is None:
                print("Need webroot (DocumentRoot) path, e.g.: -w /var/www/{}"
                    .format(domainlist[0]))
            else:
                print("Need webroot (DocumentRoot) path, e.g.: -w /var/www/{} or use --auto to use/create {}"
                    .format(domainlist[0], guessed))
            sys.exit(1)

    # refuse to create a vhost for a host name that is already configured
    root = Node(apacheconfig)
    for vhost in root.children('<VirtualHost>'):
        vhnames = get_all_hostnames(vhost)
        for host in domainlist:
            if host in vhnames:
                print("Problem: host {} found in vhost {}:{}".format(host, vhost.path, vhost.line))
                sys.exit(1)

    # Good, now, create it finally!
    if os.path.exists(config):
        # append to the existing file, separated by an empty line
        root = Node(config)
        root.insert('')
    else:
        root = Node()

    comment = '# This VirtualHost was created by a2vhost utility'

    new_vhost = Node(raw='<VirtualHost *:80>')
    new_vhost.insert('ServerName {}'.format(domainlist[0]))
    if len(domainlist) > 1:
        new_vhost.insert('ServerAlias {}'.format(' '.join(domainlist[1:])))
    new_vhost.insert('DocumentRoot {}'.format(webroot))

    root.insert([comment, '', new_vhost])
    root.write_file(config)


def make_convert(apacheconfig, domainlist):
    """
    Convert the plain-HTTP (:80) vhost of a domain into an HTTPS (:443) one.

    The vhost serving domainlist[0] gets SSL directives pointing at the
    Let's Encrypt certificate of that domain plus an HSTS header, and its
    <VirtualHost> address is changed from :80 to :443. Only the file that
    contains the vhost is rewritten.

    apacheconfig -- main apache config file
    domainlist   -- domains of the site; the first one selects the vhost and
                    the certificate directory

    Prints a message and exits with status 1 when no domain is given or when
    a :443 vhost already serves one of the domains. Raises VhostNotFound
    (via get_vhost_by_host) when there is no :80 vhost for the first domain.
    """
    # sanity check
    if not domainlist:
        print("Specify at least one domain name: -d example.com")
        sys.exit(1)

    # check maybe SSL host already exists
    root = Node(apacheconfig)
    for vhost in root.children('<VirtualHost>'):
        if ':443' not in vhost.args:
            continue
        vhnames = get_all_hostnames(vhost)
        for host in domainlist:
            if host in vhnames:
                print("Problem: host {} found in vhost {}:{}".format(host, vhost.path, vhost.line))
                sys.exit(1)

    # get proper root node
    # stage 1: find the vhost from the main config (already parsed above),
    # e.g. /etc/apache2/apache2.conf, to learn which file it lives in
    vhost_path = get_vhost_by_host(root, domainlist[0], ':80').path
    # stage 2: re-open just that file so write_file() rewrites only it
    root = Node(vhost_path)
    vhost = get_vhost_by_host(root, domainlist[0], ':80')

    # make block and insert it!
    ssl_block = [
        '',
        'SSLEngine On',
        'SSLCertificateFile /etc/letsencrypt/live/{}/fullchain.pem'.format(domainlist[0]),
        'SSLCertificateKeyFile /etc/letsencrypt/live/{}/privkey.pem'.format(domainlist[0]),
        'Header always set Strict-Transport-Security "max-age=31536000; includeSubDomains"'
    ]

    vhost.insert(ssl_block, after='DocumentRoot')

    vhost.args = vhost.args.replace(':80', ':443')
    root.write_file(root.path)

def make_redirect(apacheconfig, domainlist):
    """
    Add a plain-HTTP vhost that redirects a domain to its HTTPS vhost.

    A new <VirtualHost *:80> is inserted right after the :443 vhost serving
    domainlist[0]. It reuses that vhost's ServerName, ServerAlias and
    DocumentRoot directives and rewrites every request to https://, except
    paths under /.well-known.

    apacheconfig -- main apache config file
    domainlist   -- domains of the site; the first one selects the vhost

    Raises ArgumentError when no domain is given and VhostNotFound (via
    get_vhost_by_host) when there is no :443 vhost for the first domain.
    Prints a message and exits with status 1 when a :80 vhost already serves
    one of the domains.
    """
    # sanity check
    if not domainlist:
        raise ArgumentError("Specify at least one domain name: -d example.com")

    # check maybe a plain HTTP (:80) vhost already exists for these hosts
    root = Node(apacheconfig)
    for vhost in root.children('<VirtualHost>'):
        if ':80' not in vhost.args:
            continue
        vhnames = get_all_hostnames(vhost)
        for host in domainlist:
            if host in vhnames:
                print("Problem: host {} found in vhost {}:{}".format(host, vhost.path, vhost.line))
                sys.exit(1)

    # get proper root node: the file that actually contains the :443 vhost
    root = Node(apacheconfig)
    root = Node(get_vhost_by_host(root, domainlist[0], ':443').path)
    vhost = get_vhost_by_host(root, domainlist[0], ':443')
    servername = vhost.first('servername')
    documentroot = vhost.first('documentroot')

    new_vhost = Node(raw='<VirtualHost *:80>')
    # the vhost may have matched through a ServerAlias only, or lack a
    # DocumentRoot; skip directives that are absent instead of inserting None
    if servername is not None:
        new_vhost.insert(servername)
    for alias in vhost.children('serveralias'):
        new_vhost.insert(alias)
    if documentroot is not None:
        new_vhost.insert(documentroot)

    new_vhost.insert('RewriteEngine On')
    new_vhost.insert('RewriteCond %{HTTPS} !=on')
    # raw string: "\." and "\-" are regex escapes for Apache, not Python
    # escape sequences (a plain literal triggers an invalid-escape warning)
    new_vhost.insert(r'RewriteCond %{REQUEST_URI} !^/\.well\-known')
    new_vhost.insert('RewriteRule (.*) https://%{SERVER_NAME}$1 [R=301,L]')

    root.insert(['', '# auto-generated plain HTTP site for redirect', new_vhost], after=vhost)
    root.write_file(root.path)

def add_directive(apacheconfig, vhconfig, domain, vhspec, directive):
    """
    Insert a configuration directive into the vhost of a domain.

    apacheconfig -- main apache config, used to locate the vhost file when
                    vhconfig is not given
    vhconfig     -- file containing the vhost (optional)
    domain       -- list of domains; the first one selects the vhost
    vhspec       -- optional <VirtualHost> argument filter (e.g. '*:80')
    directive    -- directive text to insert

    Raises ArgumentError when no domain is given, VhostNotFound when the
    vhost cannot be found.
    """
    # domain[0] is needed even when vhconfig is given explicitly, so check
    # it here rather than only inside get_vhconfig
    if not domain:
        raise ArgumentError("Domain (-d) is required")

    vhconfig = vhconfig or get_vhconfig(apacheconfig, vhspec, domain)

    root = Node(vhconfig)
    vhost = get_vhost_by_host(root, domain[0], arg=vhspec)
    vhost.insert(directive)
    root.write_file(root.path)

def rm_directive(apacheconfig, vhconfig, domain, vhspec, directive):
    """
    Remove directives from the vhost of a domain.

    directive is either "Name" (removes every directive with that name) or
    "Name arguments" (removes only directives whose arguments are exactly
    equal). Each removed directive is reported on stdout.

    apacheconfig -- main apache config, used to locate the vhost file when
                    vhconfig is not given
    vhconfig     -- file containing the vhost (optional)
    domain       -- list of domains; the first one selects the vhost
    vhspec       -- optional <VirtualHost> argument filter (e.g. '*:80')

    Raises ArgumentError when no domain is given, VhostNotFound when the
    vhost cannot be found.
    """
    # domain[0] is needed even when vhconfig is given explicitly, so check
    # it here rather than only inside get_vhconfig
    if not domain:
        raise ArgumentError("Domain (-d) is required")

    # split into directive name and (optional) arguments at the first space
    cmd, sep, rest = directive.partition(' ')
    args = rest if sep else None

    vhconfig = vhconfig or get_vhconfig(apacheconfig, vhspec, domain)

    root = Node(vhconfig)
    vhost = get_vhost_by_host(root, domain[0], arg=vhspec)

    # collect first, delete afterwards: do not modify the tree while
    # children() is still iterating over it
    deleted = [n for n in vhost.children(cmd) if not args or args == n.args]

    for n in deleted:
        print("removed:", n, n.args)
        n.delete()

    root.write_file(root.path)

def delete_vhost(apacheconfig, vhconfig, domain, vhspec):
    """
    Delete the vhost of a domain from its config file.

    apacheconfig -- main apache config, used to locate the vhost file when
                    vhconfig is not given
    vhconfig     -- file containing the vhost (optional)
    domain       -- list of domains; the first one selects the vhost
    vhspec       -- optional <VirtualHost> argument filter (e.g. '*:80')

    Raises ArgumentError when no domain is given, VhostNotFound when the
    vhost cannot be found.
    """
    # domain[0] is needed even when vhconfig is given explicitly, so check
    # it here rather than only inside get_vhconfig
    if not domain:
        raise ArgumentError("Domain (-d) is required")

    vhconfig = vhconfig or get_vhconfig(apacheconfig, vhspec, domain)
    root = Node(vhconfig)
    vhost = get_vhost_by_host(root, domain[0], arg=vhspec)
    vhost.delete()
    root.write_file(root.path)

def dump_vhost(apacheconfig, vhconfig, domain, vhspec, verbose=False):
    """
    Dump the vhost of a domain.

    apacheconfig -- main apache config, used to locate the vhost file when
                    vhconfig is not given
    vhconfig     -- file containing the vhost (optional)
    domain       -- list of domains; the first one selects the vhost
    vhspec       -- optional <VirtualHost> argument filter (e.g. '*:80')
    verbose      -- when true, first print a "# path:line" comment line

    Raises ArgumentError when no domain is given, VhostNotFound when the
    vhost cannot be found.
    """
    # domain[0] is needed even when vhconfig is given explicitly, so check
    # it here rather than only inside get_vhconfig
    if not domain:
        raise ArgumentError("Domain (-d) is required")

    vhconfig = vhconfig or get_vhconfig(apacheconfig, vhspec, domain)

    root = Node(vhconfig)
    vhost = get_vhost_by_host(root, domain[0], arg=vhspec)
    if verbose:
        print('# {path}:{line}'.format(path=vhost.path, line=vhost.line))
    vhost.dump()

def check_vhosts(config, localip, verbose):
    """
    Check every VirtualHost for typical problems and print a report.

    For each <VirtualHost> found from config this verifies that a
    DocumentRoot is configured and is an existing directory, and that every
    host name (ServerName/ServerAlias) resolves to IPv4 addresses that are
    all contained in localip. Vhosts with problems are printed with the list
    of problems.

    config  -- apache config file to read
    localip -- list of local IPv4 addresses (strings); when None the address
               is fetched from http://ifconfig.me/
    verbose -- when true, print progress and a per-vhost summary even if no
               host failed
    """

    def title(vhost):
        # a vhost without ServerName must not crash the whole report
        node = vhost.first('servername')
        servername = node.args if node is not None else None
        return "VHOST {}:{} {}".format(vhost.path, vhost.line, servername)

    def ipv4_addresses(host):
        infos = socket.getaddrinfo(host, 1, type=socket.SOCK_STREAM, family=socket.AF_INET)
        # each entry is (family, type, proto, canonname, sockaddr);
        # sockaddr[0] is the address string
        return [info[4][0] for info in infos]

    if localip is None:
        # requests has no default timeout, so set one to avoid hanging
        # forever; strip() drops any trailing newline in the response body
        localip = [requests.get('http://ifconfig.me/', timeout=10).text.strip()]

    root = Node(config)
    for vhost in root.children('<VirtualHost>'):
        droot_node = vhost.first('documentroot')
        vhnames = get_all_hostnames(vhost)

        if verbose:
            print(".. check", title(vhost))

        errors = list()

        if droot_node is None:
            errors.append("no document root")
        elif not os.path.isdir(droot_node.args):
            errors.append("missing document root {}".format(droot_node.args))

        failed_list = list()
        ok_list = list()

        for host in vhnames:
            try:
                addresses = ipv4_addresses(host)
            except socket.gaierror:
                errors.append("hostname {} not resolving".format(host))
                failed_list.append(host)
            else:
                if all(i in localip for i in addresses):
                    ok_list.append(host)
                    if verbose:
                        print(".. host {} is valid".format(host))
                else:
                    failed_list.append(host)
                    errors.append("hostname {} points to other IP {}".format(host, addresses))

        failedhost = len(failed_list)
        if failedhost > 0 or verbose:
            errors.append("{}/{} hosts failed. failed: {} works: {}".format(failedhost, len(vhnames), failed_list, ok_list))

        if errors:
            print("-------------")
            print(title(vhost))
            print("-------------")
            for e in errors:
                print(e)
            print()


def list_aliases(config, domainlist, vhspec):
    """
    Print all host names of the vhosts serving the given domains.

    Output is a single space-separated line: the first domain from
    domainlist, then every other ServerName/ServerAlias of the matching
    vhosts in sorted order.

    config     -- apache config file to read
    domainlist -- domains to look up; each must be served by some vhost
    vhspec     -- optional <VirtualHost> argument filter (e.g. '*:80')

    Raises ArgumentError when no domain is given, VhostNotFound when a
    domain is not served by any (matching) vhost.
    """
    if not domainlist:
        raise ArgumentError("Domain (-d) is required")

    root = Node(config)
    maindomain = domainlist[0]
    allnames = set()

    for domain in domainlist:
        vhost = get_vhost_by_host(root, domain, arg=vhspec)
        allnames.update(get_all_hostnames(vhost))

    # main domain goes first, so take it out of the set of remaining names
    allnames.discard(maindomain)
    # sorted(): set iteration order is arbitrary, keep the output stable
    outlist = [maindomain] + sorted(allnames)
    print(' '.join(outlist))




def main():
    """
    Command-line entry point: parse arguments and run the selected action.

    Exactly one action is executed, the first one set in this order: --list,
    --check, --dump, --basic, --both, --convert, --redirect, --delete,
    --add, --rm, --aliases. Errors derived from MyException are printed as
    "Error: ..." and the process exits with status 1.
    """

    def_apacheconf = '/etc/apache2/apache2.conf'

    parser = argparse.ArgumentParser(description='Apache2 CLI vhost manager')

    g = parser.add_argument_group('Show vhost')
    g.add_argument('--list', default=False, action='store_true',
                   help='List VirtualHosts')
    g.add_argument('--dump', default=False, action='store_true',
                   help='Dump vhost content')
    g.add_argument('--check', default=False, action='store_true',
                   help='Check all vhosts for typical config problems')
    g.add_argument('--aliases', default=False, action='store_true',
                   help='Show all hostnames for --domain(s)')

    g = parser.add_argument_group('Manage vhosts')

    g.add_argument('--basic', default=False, action='store_true',
                   help='Create basic HTTP site (use: -c and --domain, --webroot)')
    g.add_argument('--convert', default=False, action='store_true',
                   help='Convert HTTP --domain site to HTTPS')
    g.add_argument('--redirect', default=False, action='store_true',
                   help='Create HTTP --domain redirect vhost to HTTPS')
    g.add_argument('--both', default=False, action='store_true',
                   help='Create both http/https sites and request certificate')
    g.add_argument('--delete', default=False, action='store_true',
                   help='Delete vhost (use -d, --vhost and (optionally) -c/-a)')

    g = parser.add_argument_group('Manage vhost configuration')
    g.add_argument('--add', metavar='DIRECTIVE',
                   help='Add custom apache configuration directive to vhost')
    g.add_argument('--rm', metavar='DIRECTIVE',
                   help='Remove custom apache configuration directive from vhost')

    g = parser.add_argument_group('Options')
    g.add_argument('-a', '--apacheconfig', default=def_apacheconf, metavar='CONF',
                   help='Main apache config file. def: {}'.format(def_apacheconf))
    g.add_argument('-c', '--config', default=None, metavar='VHOST_CONF',
                   help='VirtualHost config file.')
    g.add_argument('-w', '--webroot', default=None, metavar='PATH',
                   help='Webroot (DocumentRoot) for new site')
    g.add_argument('-v', '--verbose', default=False, action='store_true',
                   help='Be verbose')
    g.add_argument('-d', '--domain', nargs='*', metavar='DOMAIN', help='hostname/domain(s) for new website')
    g.add_argument('--vhost', metavar='VHOST', help='Process only this vhost (e.g. *:80)')
    # adjacent literals are concatenated by the parser; no quotes or
    # backslash continuation may appear inside the string itself
    g.add_argument('--auto', default=False, action='store_true',
                   help='Autogenerate/guess missing params '
                        '(config file name, create missing directories)')
    g.add_argument('--localip', metavar='IPv4', nargs='+', help='Local IP address(es) for --check')

    args = parser.parse_args()

    try:
        if args.list:
            list_vhosts(args.apacheconfig)
        elif args.check:
            # pass the flag itself; the args namespace is always truthy
            check_vhosts(args.apacheconfig, args.localip, verbose=args.verbose)
        elif args.dump:
            dump_vhost(args.apacheconfig, args.config, args.domain, args.vhost, verbose=args.verbose)

        elif args.basic:
            make_basic(args.apacheconfig, args.config,  args.domain, args.webroot, args.auto)
        elif args.both:
            make_both(args.apacheconfig, args.config,  args.domain, args.webroot, args.auto)
        elif args.convert:
            make_convert(args.apacheconfig, args.domain)
        elif args.redirect:
            make_redirect(args.apacheconfig, args.domain)
        elif args.delete:
            delete_vhost(args.apacheconfig, args.config, args.domain, args.vhost)

        elif args.add:
            add_directive(args.apacheconfig, args.config, args.domain, args.vhost, args.add)
        elif args.rm:
            rm_directive(args.apacheconfig, args.config, args.domain, args.vhost, args.rm)
        elif args.aliases:
            list_aliases(args.apacheconfig, args.domain, args.vhost)

    except MyException as e:
        print("Error:", e)
        # non-zero status, like the other error paths in this file
        sys.exit(1)

# Run the CLI only when executed as a script, so that importing this module
# (e.g. for testing) does not parse sys.argv or touch the apache config.
if __name__ == '__main__':
    main()