#!/usr/bin/python3

import argparse
import atexit
import datetime
import functools
import hashlib
import ipaddress
import os
import platform
import pwd
import signal
import socket
import stat
import subprocess
import sys
import textwrap
import time
import urllib.request
import zipfile
from collections import OrderedDict

from basefs import utils, validators

try:
    from basefs import exceptions
    from basefs.config import get_or_create_config, defaults, get_port
    from basefs.management import resources, run, mount, bootstrap, get
    from basefs.management.utils import get_cmd_port
    from basefs.keys import Key
    from basefs.logs import Log
    from basefs.state import BlockState
    from basefs.views import View
except ImportError:
    # Only the 'installserf' command may run without these modules. Every
    # other invocation, including running with no arguments at all, needs
    # them, so re-raise instead of failing later with a NameError.
    if len(sys.argv) < 2 or sys.argv[1] != 'installserf':
        raise


def send_command(cmd, *args, port=None):
    """Send a command to the basefs instance listening on localhost.

    The request is the letter ``c`` immediately followed by ``cmd``, then the
    extra ``args`` separated by single spaces, encoded to bytes and passed to
    ``utils.netcat`` for 127.0.0.1 on ``port``. The parts produced by
    ``utils.netcat`` are concatenated and returned as one string.

    Raises ValueError when ``port`` is missing or falsy. An empty reply is
    reported on stderr and still returned.
    """
    if not port:
        raise ValueError("Missing port value")
    request = ' '.join(('c' + cmd,) + args).encode()
    reply = ''.join(utils.netcat('127.0.0.1', port, request))
    if not reply:
        sys.stderr.write("No data received from 127.0.0.1:%i\n" % port)
    return reply


def genkey():
    """Generate a new EC private key and save it to disk.

    Command line: ``genkey [<keypath>] [-f]``.

    The directory of the default keypath is created when missing; for any
    other keypath the directory must already exist. An existing key file is
    only overwritten with ``--force``. When the configuration has no
    ``keypath`` in its DEFAULT section yet, the chosen keypath is stored
    there. Exits with status 2 on a usage error, otherwise with a success
    status.
    """
    # genkey [<keypath>] [-f]
    genkey_parser.add_argument('keypath', nargs='?',
        default=defaults.keypath,
        help='Path to the EC private key. %s by default.' % defaults.keypath)
    genkey_parser.add_argument('-f', '--force', dest='force', action='store_true',
        help='Rewrite key file if present.')
    args = genkey_parser.parse_args()
    keypath = args.keypath
    # A bare filename has an empty dirname, and os.path.exists('') is False;
    # treat it as the current directory so relative key names are accepted.
    keydir = os.path.dirname(keypath) or os.curdir
    if not os.path.exists(keydir):
        if keypath == defaults.keypath:
            # makedirs also creates missing parents of the default directory.
            os.makedirs(keydir, exist_ok=True)
        else:
            sys.stderr.write("Error: %s keypath directory doesn't exist, create it first.\n" % keydir)
            sys.exit(2)
    elif not args.force and os.path.exists(keypath):
        sys.stderr.write('Error: %s key already exists, use --force to override it.\n' % keypath)
        sys.exit(2)
    key = Key.generate()
    key.save(keypath)
    config = get_or_create_config(defaults)
    # Remember the first generated key as the default one.
    if 'keypath' not in config['DEFAULT']:
        config['DEFAULT']['keypath'] = keypath
        config.save()
    sys.stdout.write("Generate EC key on %s\n" % keypath)
    sys.exit()


def keys():
    """List the keys found in a basefs log.

    Loads the log, builds its view and prints the mapping returned by
    ``view.get_keys``: every mapping key on its own line, followed by its
    values indented by four spaces. ``--by-dir`` is passed through to
    ``get_keys`` as ``by_dir`` and ``--path`` as ``path``.
    """
    # NOTE(review): ``default_logpath`` is not defined or imported anywhere in
    # the visible file, so this line looks like it raises NameError -- confirm
    # where the default log path is supposed to come from.
    logpath_help = 'Path to the basefs log file, uses %s by default.' % default_logpath
    keys_parser.add_argument(
        'logpath', nargs='?', default=default_logpath,
        type=validators.file_exists(keys_parser, name='logpath'),
        help=logpath_help)
    keys_parser.add_argument(
        '-p', '--path', dest='path', default='/',
        help='Base path.')
    keys_parser.add_argument(
        '-d', '--by-dir', dest='by_dir', action='store_true',
        help='List keys by dir instead of by key.')
    options = keys_parser.parse_args()

    history = Log(options.logpath)
    history.load()
    view = View(history)
    view.build()

    emit = sys.stdout.write
    listing = view.get_keys(path=options.path, by_dir=options.by_dir)
    for title, entries in listing.items():
        emit(title + '\n')
        for entry in entries:
            emit('    ' + entry + '\n')


def list_cmd():
    """Print a table describing every known basefs log.

    Logs are collected from the named sections of the configuration and from
    the files in ``defaults.logdir`` whose name is not a configured section.
    Each row shows the name, log path, default port, mountpoint, member
    count, log size and modification time. Logs whose file cannot be found
    are left out of the table.
    """
    config = get_or_create_config(defaults)
    data = [
        ['Name', 'Logpath', 'Def. port', 'Moutpoint', 'Members', 'Logsize', 'Last update'],
        ['----', '-------', '---------', '---------', '-------', '-------', '-----------'],
    ]

    def get_info(logpath, port):
        """Return (mountpoint, members, size, mtime) as strings for one log.

        Raises FileNotFoundError when ``logpath`` does not exist.
        """
        statinfo = os.stat(logpath)
        size = utils.sizeof_fmt(statinfo.st_size)
        mtime = datetime.datetime.fromtimestamp(statinfo.st_mtime).strftime('%Y-%m-%d %H:%M:%S')
        minfo, mpoint = utils.get_mountpoint(logpath)
        members = '-'
        if minfo:
            mport = int(minfo.split(':')[1])
            # ``port`` is carried as text for the table while ``mport`` is an
            # int; compare both as text so that equal ports really match and
            # the suffix only appears for a non-default port.
            if str(mport) != str(port).strip():
                mpoint += ':%i' % mport
            members = str(len(send_command('MEMBERS', port=mport).splitlines()))
        else:
            mpoint = 'Not mounted'
        return mpoint, members, size, mtime

    names = []
    for section, content in config.items():
        if section == 'DEFAULT':
            continue
        names.append(section)
        logpath = content['logpath']
        port = content['port']
        try:
            mpoint, members, size, mtime = get_info(logpath, port)
        except FileNotFoundError:
            # Configured log whose file is gone: leave it out of the table.
            continue
        data.append([section, logpath, port, mpoint, members, size, mtime])
    for log in os.listdir(defaults.logdir):
        if log in names:
            continue
        logpath = os.path.join(defaults.logdir, log)
        port = str(get_port(log))
        try:
            mpoint, members, size, mtime = get_info(logpath, port)
        except FileNotFoundError:
            # Removed between listdir() and stat(); skip it like a missing
            # configured log.
            continue
        data.append([log, logpath, port, mpoint, members, size, mtime])
    sys.stdout.write(utils.tabluate(data) + '\n')


def _strip_mountpoint(path, mountpoint):
    """Return ``path`` expressed relative to ``mountpoint``, rooted at os.sep.

    ``path`` is returned unchanged when it is not located under
    ``mountpoint``; the mountpoint itself maps to ``os.sep``.
    """
    root = mountpoint.rstrip(os.sep)
    if path == root or path.startswith(root + os.sep):
        return path[len(root):] or os.sep
    return path


def show():
    """Print the tree representation of a log entry.

    ``logpath`` may be a configured log name, a file name inside
    ``defaults.logdir`` or a path to a log file. ``path`` selects the entry
    to print. When ``utils.get_mount_info()`` reports a mount, ``path`` is
    first tried relative to that mountpoint and then as given; otherwise it
    is looked up as given. Exits with status 2 when no candidate exists in
    the log.
    """
    config = get_or_create_config(defaults)
    show_parser.add_argument('logpath',
        type=validators.name_or_logpath(show_parser, config, defaults),
        help='Log name or logpath')
    show_parser.add_argument('path', nargs='?', default=os.sep,
        help='Path to the basefs log file, uses / by default.')
    show_parser.add_argument('-a', '--ascii', dest='ascii', action='store_true',
        help='use ASCII line drawing characters')
    show_parser.add_argument('-c', '--color', dest='color', action='store_true',
        help='use terminal coloring')
    args = show_parser.parse_args()
    logpath = args.logpath
    if logpath in config:
        logpath = config[args.logpath]['logpath']
    elif not os.path.exists(logpath) and os.path.exists(os.path.join(defaults.logdir, logpath)):
        logpath = os.path.join(defaults.logdir, logpath)
    log = Log(logpath)
    log.load()
    view = View(log)
    view.build()

    def print_tree(entry):
        tree = log.print_tree(entry=entry, view=view, color=args.color, ascii=args.ascii)
        if args.ascii:
            # Write bytes directly so non-ASCII characters are replaced
            # instead of depending on the terminal encoding.
            sys.stdout.buffer.write(tree.encode('ascii', errors='replace'))
        else:
            sys.stdout.write(tree)

    # Candidate in-log paths, in lookup order.
    candidates = []
    mount_info = utils.get_mount_info()
    if mount_info:
        # Strip the mountpoint only as a leading prefix, never in the middle
        # of the path.
        candidates.append(
            _strip_mountpoint(os.path.abspath(args.path), mount_info.mountpoint))
    if args.path not in candidates:
        candidates.append(args.path)

    for candidate in candidates:
        try:
            entry = log.find(candidate)
        except exceptions.DoesNotExist:
            continue
        # Print whichever candidate resolves first, including the fallback.
        print_tree(entry)
        break
    else:
        sys.stderr.write("Error: '%s' path does not exist on the log.\n" % args.path)
        sys.exit(2)


def grant():
    """Ask the mounted basefs instance to grant a key write permission.

    Command line: ``grant <grantpath> <grantkey>``.

    ``grantpath`` must be located inside a basefs mountpoint; it is converted
    to a path relative to that mountpoint, rooted at ``/``, and sent together
    with the key one-liner as a ``GRANT`` command to the port of that mount.
    The reply is written to stdout. Exits with status 2 when ``grantpath`` is
    not under a basefs mountpoint.
    """
    # grant <keypath/keyoneliner/keyfingerprint/keyname> <fspath>
    grant_parser.add_argument('grantpath',
        help='Path where the permission should be granted.')
    grant_parser.add_argument('grantkey',
        help='Key fingerprint, if exists on lskeys, or path to a public key.',
        type=validators.key(grant_parser))
    args = grant_parser.parse_args()
    mount_info = utils.get_mount_info(os.path.abspath(args.grantpath))
    if mount_info is None:
        sys.stderr.write("Error: grantpath '%s' is not a basefs mountpoint subdir\n" % args.grantpath)
        sys.exit(2)
    path = os.path.relpath(args.grantpath, mount_info.mountpoint)
    # relpath() returns '.' for the mountpoint itself. Only that exact value
    # maps to the root; a name that merely starts with a dot (e.g. '.ssh')
    # must keep its dot.
    if path == os.curdir:
        path = ''
    path = '/' + path
    context = {
        'path': path,
        'grant_key': args.grantkey.oneliner(),
    }
    grant_cmd = 'c GRANT %(path)s %(grant_key)s' % context
    # Same calling convention as send_command() in this file: the request is
    # passed as bytes and the parts of the reply are concatenated.
    response = ''.join(utils.netcat('127.0.0.1', mount_info.port, grant_cmd.encode()))
    sys.stdout.write(response)
    sys.exit()


def revoke():
    """Revoke the write permission of a key on a path.

    Command line: ``revoke <revokekey> [<revokepath>]``; ``revokepath``
    defaults to ``/``.
    """
    # revoke <keyfinger/keyname> <fspath>
    # Both arguments belong to revoke_parser, the parser used below to parse
    # the command line.
    revoke_parser.add_argument('revokekey',
        help='Key fingerprint, if exists on lskeys, or path to a public key.',
        type=validators.fingerprint(revoke_parser))
    revoke_parser.add_argument('revokepath', nargs='?', default='/',
        help='Path where the permission should be revoked. Defaults to /.')
    args = revoke_parser.parse_args()
    # NOTE(review): no 'logpath' or '-k/--key' argument is registered on
    # revoke_parser in this file (they used to exist but were disabled and
    # referred to default_logpath / default_keypath, which are not defined
    # here), so args.logpath and args.key look like they raise AttributeError.
    # Confirm how the log and the signing key should be located.
    log = Log(args.logpath)
    view = View(log, args.key)
    view.revoke(args.revokepath, args.revokekey.fingerprint)
    sys.exit()


def revert():
    """Revert a file or directory to a previous revision.

    Command line: ``revert <path> <hash>``. The revert itself is delegated to
    ``View.revert``; the process then exits with a success status.
    """
    path_help = 'Path of the directory or file to revert'
    hash_help = "Hash of a previous revision, use 'basefs log path' for showing all revisions"
    revert_parser.add_argument('path', help=path_help)
    revert_parser.add_argument('hash', help=hash_help)
    options = revert_parser.parse_args()
    # NOTE(review): no 'logpath' or '-k/--key' argument is registered on
    # revert_parser in this file (they used to exist but were disabled), so
    # options.logpath and options.key look like they raise AttributeError --
    # confirm how the log and the signing key should be located.
    history = Log(options.logpath)
    View(history, options.key).revert(options.path, options.hash)
    sys.exit()


def blocks():
    """Poll and print the block state of a basefs instance once per second.

    The command port is resolved by ``get_cmd_port`` from the optional
    ``name`` argument and the current mount info. Polling continues until it
    is interrupted with Ctrl-C, after which the command exits normally.
    """
    blocks_parser.add_argument('name', nargs='?', help='Name')
    args = blocks_parser.parse_args()
    mount_info = utils.get_mount_info()
    port = get_cmd_port(args, mount_info, defaults)
    try:
        while True:
            result = send_command('BLOCKSTATE', port=port)
            sys.stdout.write(result + '\n')
            # Flush so every poll shows up immediately, also through a pipe.
            sys.stdout.flush()
            time.sleep(1)
    except KeyboardInterrupt:
        # Ctrl-C is the only way out of the loop; leave without a traceback
        # so the exit below is actually reached.
        pass
    sys.exit()


def members():
    """Print the reply of the MEMBERS command of a basefs instance and exit.

    The command port is resolved by ``get_cmd_port`` from the optional
    ``name`` argument and the current mount info.
    """
    members_parser.add_argument('name', nargs='?', help='Name')
    options = members_parser.parse_args()
    current_mount = utils.get_mount_info()
    cmd_port = get_cmd_port(options, current_mount, defaults)
    listing = send_command('MEMBERS', port=cmd_port)
    sys.stdout.write(listing + '\n')
    sys.exit()


def serf():
    """Proxy a Serf command to the Serf agent of a basefs instance.

    The ``state`` command is answered by basefs itself through its command
    port. Any other command runs the ``serf`` executable against
    127.0.0.1 on the command port minus one, with the remaining command-line
    arguments passed through unchanged.
    """
    serf_parser.add_argument('name', help='Name')
    serf_parser.add_argument('cmd', help='Serf command')
    serf_parser.add_argument('args', metavar='N', nargs='*', help='Serf command arguments')
    args = serf_parser.parse_args()
    mount_info = utils.get_mount_info()
    cmd_port = get_cmd_port(args, mount_info, defaults)
    if args.cmd in ('state',):
        result = send_command('SERF'+args.cmd.upper(), port=cmd_port)
        sys.stdout.write(result + '\n')
    else:
        rpc_port = cmd_port-1
        # Use an argument list instead of a shell string: every argument
        # reaches serf as its own token, nothing is re-parsed by a shell, and
        # no empty argument is sent when there are no extra arguments.
        command = ['serf', args.cmd, '--rpc-addr=127.0.0.1:%i' % rpc_port] + args.args
        try:
            subprocess.call(command)
        except FileNotFoundError:
            sys.stderr.write("Error: serf executable not found, run 'basefs installserf' first.\n")


def installserf():
    """Download the Serf 0.6.4 Linux release and install it into ``binpath``.

    ``binpath`` defaults to /usr/local/bin/. The archive matching the machine
    architecture is downloaded to a temporary file, an existing ``serf`` in
    ``binpath`` is removed, the archive is extracted there and the execute
    bits are added to the extracted binary. The temporary download is always
    removed. Exits with status 2 when the architecture is not in the
    supported table.
    """
    installserf_parser.add_argument('binpath', nargs='?', default='/usr/local/bin/',
        type=validators.dir_exists(installserf_parser, name='binpath'),
        help='Binpath for installing serf, defaults to /usr/local/bin/')
    # platform.machine() value -> architecture name used in the release URL.
    architectures = {
        'x86_64': 'amd64',
        'AMD64': 'amd64',
        'i686': '386',
        'i586': '386',
        'i386': '386',
        'x86': '386',
    }
    args = installserf_parser.parse_args()
    machine = platform.machine()
    try:
        arch = architectures[machine]
    except KeyError:
        # Report an unknown machine type instead of a bare KeyError traceback.
        sys.stderr.write("Error: unsupported architecture '%s', supported: %s.\n" % (
            machine, ', '.join(sorted(architectures))))
        sys.exit(2)
    url = "https://releases.hashicorp.com/serf/0.6.4/serf_0.6.4_linux_%s.zip" % arch
    sys.stdout.write("Donwloading %s ...\n" % url)
    path, _ = urllib.request.urlretrieve(url)
    try:
        serf_path = os.path.join(args.binpath, 'serf')
        sys.stdout.write("Unpacking into %s ...\n" % serf_path)
        if os.path.exists(serf_path):
            os.remove(serf_path)
        # The context manager closes the archive even if extraction fails.
        with zipfile.ZipFile(path, 'r') as zip_ref:
            zip_ref.extractall(args.binpath)
        st = os.stat(serf_path)
        sys.stdout.write("changing mode of %s to 775\n" % serf_path)
        # Add the execute bit for user, group and others to the current mode.
        os.chmod(serf_path, st.st_mode | 0o0111)
    finally:
        os.remove(path)


def help():
    """Print the top-level usage text, one line per command, then exit.

    Reads the module-level ``methods`` mapping of command name to
    ``(callable, parser)``. Commands with a parser show its description,
    separated from the name by tabs so the descriptions line up; commands
    without a parser show only their name.
    """
    longest = max((len(name) for name in methods), default=0)
    tab_stops = (longest + 4) // 8
    listing = []
    for name, (_, parser) in methods.items():
        line = '    ' + name
        if parser:
            # Shorter names get more tabs so every description starts at the
            # same tab stop.
            padding = '\t' * (tab_stops - len(line) // 8 + 1)
            line += padding + parser.description
        listing.append(line)
    template = textwrap.dedent("""\
        Usage: basefs COMMAND [arg...]
               basefs [ --help | -v | --version ]

        Basically Available, Soft state, Eventually consistent File System.

        Commands:
        %s

        Run 'basefs COMMAND --help' for more information on a command
        """)
    sys.stdout.write(template % '\n'.join(listing))
    sys.exit()


def _subcommand_parser(name, description):
    """Build the argument parser of the ``basefs <name>`` subcommand."""
    return argparse.ArgumentParser(prog='basefs ' + name, description=description)


# One parser per subcommand. The command functions register their own
# arguments on these parsers when they run.
genkey_parser = _subcommand_parser('genkey', 'Generate a new EC private key')
keys_parser = _subcommand_parser('keys', 'List keys and their directories')
list_parser = _subcommand_parser('list', 'List all available logs')
show_parser = _subcommand_parser('show', 'Show a log file using a tree representation')
grant_parser = _subcommand_parser('grant', 'Grant key write permission')
revoke_parser = _subcommand_parser('revoke', 'Revoke key write permission')
revert_parser = _subcommand_parser(
    'revert', "Revert object to previous state, 'log' command lists all revisions")
blocks_parser = _subcommand_parser('blocks', "Block state")
members_parser = _subcommand_parser('members', "List cluster members")
serf_parser = _subcommand_parser('serf', "Serf RPC Command proxy")
installserf_parser = _subcommand_parser('installserf', "Download and install Serf")


if __name__ == '__main__':
    if len(sys.argv) > 1 and sys.argv[1] == 'installserf':
        # Handled before the command table is built, because the table refers
        # to project modules that may not have been importable.
        sys.argv.pop(1)
        installserf()
    else:
        # Command name -> (callable, parser). help() reads this mapping.
        methods = OrderedDict([
            ('mount', (mount.command, mount.parser)),
            ('run', (run.command, run.parser)),
            ('bootstrap', (bootstrap.command, bootstrap.parser)),
            ('genkey', (genkey, genkey_parser)),
            ('keys', (keys, keys_parser)),
            ('grant', (grant, grant_parser)),
            ('revoke', (revoke, revoke_parser)),
            ('list', (list_cmd, list_parser)),
            ('show', (show, show_parser)),
            ('revert', (revert, revert_parser)),
            ('blocks', (blocks, blocks_parser)),
            ('members', (members, members_parser)),
            ('serf', (serf, serf_parser)),
            ('get', (get.command, get.parser)),
            ('installserf', (installserf, installserf_parser)),
            ('resources', (resources.command, resources.parser)),
            ('help', (help, None)),
        ])

        def _usage_failure():
            """Print the usage text and exit with status 1."""
            try:
                help()
            except SystemExit:
                # help() always ends in sys.exit() with a success status;
                # intercept it so a usage error reports failure.
                pass
            sys.exit(1)

        if len(sys.argv) > 1:
            # Remove the command name so each command parses only its own
            # arguments.
            method = sys.argv.pop(1)
            if method == '--help':
                method = 'help'
            try:
                method = methods[method][0]
            except KeyError:
                sys.stderr.write("Error: not recognized argument %s\n" % method)
                _usage_failure()
        else:
            _usage_failure()
        method()
