#!/usr/bin/env python
from __future__ import print_function

import argparse
import inspect
import random
import socket
import sys
from sysadmin import SysAdminClient
from sysadmin import UnpackFromProto
from sysadmin.generated import sysadminctl_pb2


def rand_uint32():
    """Return a random integer N with 0 <= N <= 2**32 - 1."""
    upper_bound = 2**32 - 1
    return random.randint(0, upper_bound)


# Value type names offered as choices for the --type option of the "set" and
# "modify" sub-commands; convertToType() has one branch per name.
type_choices = ["str", "int32", "bool", "int32list", "boollist", "stringlist"]


def to_bool(strval):
    """Interpret *strval* as a boolean.

    Only the word "true", compared case-insensitively, yields True; every
    other string (including the empty string) yields False.
    """
    lowered = strval.lower()
    return lowered == "true"


def convertToType(stringvalue, valuetype):
    """Convert command-line strings to the Python value named by *valuetype*.

    stringvalue: list of strings (the ``value`` arguments of a sub-command).
    valuetype:   one of the names in ``type_choices``.

    Scalar types ("str", "int32", "bool") use only the first string. List
    types ("int32list", "boollist", "stringlist") convert every non-empty
    string and skip empty ones.

    Returns None when a scalar type is requested with no strings at all, or
    when *valuetype* is not one of the known names.
    """
    if stringvalue:
        first = stringvalue[0]
        if valuetype == "str":
            return first
        elif valuetype == "int32":
            return int(first)
        elif valuetype == "bool":
            return to_bool(first)

    # Not a scalar request (or nothing to convert): pick the per-item
    # converter for the list types.
    if valuetype == "int32list":
        convert_item = int
    elif valuetype == "boollist":
        convert_item = to_bool
    elif valuetype == "stringlist":
        convert_item = str
    else:
        return None
    return [convert_item(item) for item in stringvalue if len(item) > 0]


def get_existing_type(response):
    """Return the type name of the first value in a ``get`` response.

    The name is the populated field of the value's "value" oneof with every
    occurrence of "val" removed.

    Returns None when the response status is not SUCCESS, when the response
    carries no key/value entries, or when the first entry has no oneof field
    set. Callers treat None as "no existing value".
    """
    if response.status != sysadminctl_pb2.SUCCESS:
        return None
    kvs = response.get.kvs
    # Guard against a successful but empty result instead of raising
    # IndexError on kvs[0].
    if len(kvs) == 0:
        return None
    field_name = kvs[0].value.WhichOneof("value")
    # WhichOneof returns None when no member of the oneof is populated.
    if field_name is None:
        return None
    return field_name.replace("val", "")


def format_kvs(kvs):
    """Render key/value entries as sorted ``key = value`` lines.

    kvs: iterable of objects with ``key`` and ``value`` attributes. Entries
    whose ``value`` is falsy are skipped. Each remaining value is decoded
    with UnpackFromProto; string values are wrapped in double quotes.

    Returns the lines sorted and joined with newlines (an empty string when
    nothing is rendered).
    """
    # ``unicode`` only exists on Python 2; referencing it unguarded raises
    # NameError on Python 3 as soon as a non-str value is formatted.
    try:
        text_types = (str, unicode)
    except NameError:
        text_types = (str,)

    kvstrings = []
    for kv in kvs:
        if not kv.value:
            continue
        value = UnpackFromProto(kv.value)
        if isinstance(value, text_types):
            value = "\"%s\"" % value
        kvstrings.append("%s = %s" % (kv.key, value))
    return "\n".join(sorted(kvstrings))


def format_set_output(response):
    """Describe the outcome of a ``set`` request.

    For SUCCESS_KEY_CREATED and SUCCESS statuses, returns a heading line
    followed by the formatted key/value entries of the response. For any
    other status the response object itself is returned unchanged.
    """
    status = response.status
    if status == sysadminctl_pb2.SUCCESS_KEY_CREATED:
        heading = "Uncommitted created new entry\n"
    elif status == sysadminctl_pb2.SUCCESS:
        heading = "Uncommitted modified entry\n"
    else:
        return response
    return heading + format_kvs(response.get.kvs)


class SetCommand(object):
    """The ``set`` sub-command: send a value for a key to the server."""

    def __init__(self, argparser):
        """Register the ``set`` sub-parser and its arguments on *argparser*."""
        set_parser = argparser.add_parser("set")
        set_parser.add_argument("key", type=str, help="key")
        set_parser.add_argument("value", type=str, nargs="*", help="value")
        set_parser.add_argument("--type", type=str, default="str",
                                choices=type_choices,
                                help="value type, default: str")
        self.parser = set_parser

    def run(self, args):
        """Convert ``args.value`` and send it; return the formatted result.

        If the key already has a value, that value's type is used for the
        conversion and ``--type`` is ignored; otherwise ``--type`` decides.
        """
        client = SysAdminClient(args.host, args.port, args.xid)
        stored_type = get_existing_type(client.get(args.key))
        target_type = args.type if stored_type is None else stored_type
        pyvalue = convertToType(args.value, target_type)
        return format_set_output(client.set(args.key, pyvalue))


class ModifyCommand(object):
    """The ``modify`` sub-command: change the value of an existing key."""

    def __init__(self, argparser):
        """Register the ``modify`` sub-parser and its arguments."""
        modify_parser = argparser.add_parser("modify")
        modify_parser.add_argument("key", type=str, help="key")
        modify_parser.add_argument("value", type=str, nargs="*", help="value")
        modify_parser.add_argument("--type", type=str, default="str",
                                   choices=type_choices,
                                   help="value type, default: str")
        self.parser = modify_parser

    def run(self, args):
        """Send the new value, but only for a key that already exists.

        Returns a "Key not found" message when the key has no existing
        value, an "Expected type" message when ``--type`` differs from the
        existing value's type, and the formatted set result otherwise.
        """
        client = SysAdminClient(args.host, args.port, args.xid)
        current_type = get_existing_type(client.get(args.key))
        if current_type is None:
            return "Key not found: " + args.key
        if args.type != current_type:
            return "Expected type: " + current_type

        converted = convertToType(args.value, current_type)
        return format_set_output(client.set(args.key, converted))


def format_get_output(key, response):
    """Describe the outcome of a ``get`` request for *key*.

    Returns the formatted key/value entries on SUCCESS, a "Key not found"
    message on KEY_NOT_FOUND, and the response object itself for any other
    status.
    """
    status = response.status
    if status == sysadminctl_pb2.SUCCESS:
        return format_kvs(response.get.kvs)
    if status == sysadminctl_pb2.KEY_NOT_FOUND:
        return "Key not found: " + key
    return response


class GetCommand(object):
    """The ``get`` sub-command: look up a key and print what is found."""

    def __init__(self, argparser):
        """Register the ``get`` sub-parser with its single ``key`` argument."""
        get_parser = argparser.add_parser("get")
        get_parser.add_argument("key", type=str, help="key")
        self.parser = get_parser

    def run(self, args):
        """Fetch ``args.key`` and return the formatted response."""
        client = SysAdminClient(args.host, args.port, args.xid)
        response = client.get(args.key)
        return format_get_output(args.key, response)


def format_commit_output(response):
    """Describe the outcome of a ``commit`` request.

    Returns "Commit successful: <commit id>" on SUCCESS; for any other
    status the response object itself is returned unchanged.
    """
    if response.status != sysadminctl_pb2.SUCCESS:
        return response
    commit_id = response.commit.commit_id
    return "Commit successful: " + str(commit_id)


class CommitCommand(object):
    """The ``commit`` sub-command: ask the server to commit."""

    # Accepted values for --options, in the order shown in the help output.
    opts = ["default", "templates", "nohooks"]

    def __init__(self, argparser):
        """Register the ``commit`` sub-parser and its ``--options`` flag."""
        commit_parser = argparser.add_parser("commit")
        commit_parser.add_argument("--options", type=str, default="default",
                                   choices=CommitCommand.opts,
                                   help="Run hooks, only templates, or nothing")
        self.parser = commit_parser

    def run(self, args):
        """Map ``--options`` to a commit option and return the formatted result."""
        client = SysAdminClient(args.host, args.port, args.xid)
        # Anything other than "templates" / "nohooks" falls back to DEFAULT.
        option_by_name = {
            "templates": sysadminctl_pb2.TEMPLATE_ONLY,
            "nohooks": sysadminctl_pb2.NO_HOOKS,
        }
        opt = option_by_name.get(args.options, sysadminctl_pb2.DEFAULT)
        return format_commit_output(client.commit(opt))


def format_blame_output(key, response):
    """Describe the outcome of a ``blame`` request for *key*.

    On SUCCESS, returns one line per blame entry in the form
    ``"<value>" in commit <commit id> (<commit time>)``, joined with
    newlines. On KEY_NOT_FOUND, returns a "Key not found" message. For any
    other status the response object itself is returned, matching the other
    format_*_output helpers, so failures are no longer reported as an empty
    string.
    """
    status = response.status
    if status == sysadminctl_pb2.KEY_NOT_FOUND:
        return "Key not found: " + key
    if status != sysadminctl_pb2.SUCCESS:
        return response
    return "\n".join(
        "\"%s\" in commit %s (%s)" % (UnpackFromProto(entry.val),
                                      entry.commit_id,
                                      entry.commit_time)
        for entry in response.blame.entries)


class BlameCommand(object):
    """The ``blame`` sub-command: request the blame entries for a key."""

    def __init__(self, argparser):
        """Register the ``blame`` sub-parser with its ``key`` argument."""
        blame_parser = argparser.add_parser("blame")
        blame_parser.add_argument("key", type=str, help="key")
        self.parser = blame_parser

    def run(self, args):
        """Request blame data for ``args.key`` and return it formatted."""
        client = SysAdminClient(args.host, args.port, args.xid)
        response = client.blame(args.key)
        return format_blame_output(args.key, response)


class DropCommand(object):
    """The ``drop`` sub-command: forwards to the client's ``drop`` call."""

    def __init__(self, argparser):
        """Register the argument-less ``drop`` sub-parser."""
        drop_parser = argparser.add_parser("drop")
        self.parser = drop_parser

    def run(self, args):
        """Issue the drop request and return the raw response."""
        connection = SysAdminClient(args.host, args.port, args.xid)
        response = connection.drop()
        return response


class FireHooksCommand(object):
    """The ``firehooks`` sub-command: forwards to the client's ``firehooks`` call."""

    def __init__(self, argparser):
        """Register the ``firehooks`` sub-parser and its ``--timeout`` flag."""
        hooks_parser = argparser.add_parser("firehooks")
        hooks_parser.add_argument("--timeout", type=int, default=None,
                                  help="firehook timeout (in seconds)")
        self.parser = hooks_parser

    def run(self, args):
        """Issue the firehooks request and return the raw response.

        ``args.timeout`` is handed to the client as its fourth constructor
        argument.
        """
        connection = SysAdminClient(args.host, args.port, args.xid,
                                    args.timeout)
        response = connection.firehooks()
        return response


class EraseCommand(object):
    """The ``erase`` sub-command: forwards a key to the client's ``erase`` call."""

    def __init__(self, argparser):
        """Register the ``erase`` sub-parser with its ``key`` argument."""
        erase_parser = argparser.add_parser("erase")
        erase_parser.add_argument("key", type=str, help="key")
        self.parser = erase_parser

    def run(self, args):
        """Issue the erase request for ``args.key`` and return the raw response."""
        connection = SysAdminClient(args.host, args.port, args.xid)
        response = connection.erase(args.key)
        return response


class RollbackCommand(object):
    """The ``rollback`` sub-command: forwards a commit id to the client's ``rollback`` call."""

    def __init__(self, argparser):
        """Register the ``rollback`` sub-parser with its integer ``id`` argument."""
        rollback_parser = argparser.add_parser("rollback")
        rollback_parser.add_argument("id", type=int,
                                     help="Commit id being rolled back")
        self.parser = rollback_parser

    def run(self, args):
        """Issue the rollback request for ``args.id`` and return the raw response."""
        connection = SysAdminClient(args.host, args.port, args.xid)
        response = connection.rollback(args.id)
        return response


class ResetCommand(object):
    """The ``reset`` sub-command: forwards to the client's ``reset`` call."""

    def __init__(self, argparser):
        """Register the argument-less ``reset`` sub-parser."""
        reset_parser = argparser.add_parser("reset")
        self.parser = reset_parser

    def run(self, args):
        """Issue the reset request and return the raw response."""
        connection = SysAdminClient(args.host, args.port, args.xid)
        response = connection.reset()
        return response

class DumpHooksCommand(object):
    """The ``dumphooks`` sub-command: forwards to the client's ``dumphooks`` call."""

    def __init__(self, argparser):
        """Register the argument-less ``dumphooks`` sub-parser."""
        dump_parser = argparser.add_parser("dumphooks")
        self.parser = dump_parser

    def run(self, args):
        """Issue the dumphooks request and return the raw response."""
        connection = SysAdminClient(args.host, args.port, args.xid)
        response = connection.dumphooks()
        return response


class TriggerCommand(object):
    """The ``trigger`` sub-command: forwards a hook name to the client's ``trigger`` call."""

    def __init__(self, argparser):
        """Register the ``trigger`` sub-parser with its ``hook`` argument."""
        trigger_parser = argparser.add_parser("trigger")
        trigger_parser.add_argument("hook", type=str,
                                    help="Name of hook to trigger")
        self.parser = trigger_parser

    def run(self, args):
        """Issue the trigger request for ``args.hook`` and return the raw response."""
        connection = SysAdminClient(args.host, args.port, args.xid)
        response = connection.trigger(args.hook)
        return response


def format_diff(key, old_get, new_item):
    """Render one ``key: old -> new`` line for an in-flight change.

    key:      the key being reported.
    old_get:  response of a ``get`` for *key*; its first entry supplies the
              old value unless the status is KEY_NOT_FOUND, in which case
              the old side reads ``None``.
    new_item: the in-flight entry; a falsy ``value`` is shown as ``Erased``.

    String values are wrapped in double quotes on both sides.
    """
    # ``unicode`` only exists on Python 2; referencing it unguarded raises
    # NameError on Python 3 as soon as a non-str value is formatted.
    try:
        text_types = (str, unicode)
    except NameError:
        text_types = (str,)

    def format_value(value):
        value = UnpackFromProto(value)
        if isinstance(value, text_types):
            value = "\"%s\"" % value
        return value

    formatted_old = 'None'
    # NOTE(review): every status other than KEY_NOT_FOUND is assumed to
    # carry at least one entry in old_get.get.kvs - confirm for error
    # statuses.
    if old_get.status != sysadminctl_pb2.KEY_NOT_FOUND:
        formatted_old = format_value(old_get.get.kvs[0].value)
    formatted_new = 'Erased'
    if new_item.value:
        formatted_new = format_value(new_item.value)
    return "{}: {} -> {}".format(key, formatted_old, formatted_new)


class DiffCommand(object):
    """The ``diff`` sub-command: list in-flight changes against current values."""

    def __init__(self, argparser):
        """Register the argument-less ``diff`` sub-parser."""
        diff_parser = argparser.add_parser("diff")
        self.parser = diff_parser

    def run(self, args):
        """Return one formatted diff line per in-flight entry.

        For every entry reported by the client's ``inflight`` call, the
        current value of the same key is fetched with ``get`` and the pair
        is rendered by format_diff. Lines are joined with newlines.
        """
        client = SysAdminClient(args.host, args.port, args.xid)
        pending = client.inflight()
        lines = [
            format_diff(entry.key, client.get(entry.key), entry)
            for entry in pending.inflight.kvs
        ]
        return '\n'.join(lines)


def _get_commands_map():
    """Map sub-command names to the command classes defined in this module.

    Every class in this module whose name ends with ``Command`` is included;
    its key is the class name with "Command" removed, lower-cased (for
    example ``GetCommand`` becomes ``"get"``).
    """
    this_module = sys.modules[__name__]
    return {
        class_name.replace("Command", "").lower(): command_class
        for class_name, command_class
        in inspect.getmembers(this_module, inspect.isclass)
        if class_name.endswith('Command')
    }


def initialize_subcommands(commands, argparser):
    """Instantiate every command class in *commands*, in place.

    commands:  dict mapping sub-command name to a command class (or any
               callable taking the sub-parsers object).
    argparser: the sub-parsers object handed to each command constructor.

    Each value in *commands* is replaced by the instance it produced.
    Returns None.
    """
    # Iterate over a snapshot of the keys: it works on both Python 2 and 3
    # (dict.iteritems is Python 2 only) and avoids writing to the dict
    # while a live iterator over it is in use.
    for name in list(commands):
        commands[name] = commands[name](argparser)


def main():
    """Parse the command line, run the chosen sub-command, print its result.

    Returns 0 on success. If the client times out, a message is written to
    stderr and a non-zero status is returned (the exception's errno when it
    has one, otherwise 1).
    """
    parser = argparse.ArgumentParser(description="Send commands to sysadmin",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--host", type=str, default='127.0.0.1',
                        help='sysadmin host')
    parser.add_argument("--port", type=int, default=4000,
                        help="sysadmin port")
    parser.add_argument("--xid", type=int, default=0,
                        help="Transaction xid")
    subparser = parser.add_subparsers(dest="command")
    # Python 3 treats sub-commands as optional unless told otherwise, which
    # would leave args.command as None and fail the lookup below.
    subparser.required = True
    command_map = _get_commands_map()
    initialize_subcommands(command_map, subparser)
    args = parser.parse_args()
    try:
        print(command_map[args.command].run(args))
        return 0
    except socket.timeout as err:
        # socket.timeout is normally raised with a bare message, leaving
        # strerror and errno as None; print the exception itself and make
        # sure the status is non-zero.
        print("Timed out talking to sysadmin at %s:%s: %s"
              % (args.host, args.port, err), file=sys.stderr)
        return err.errno or 1


if __name__ == '__main__':
    # Propagate main()'s status to the shell.
    sys.exit(main())
