#!python
import sys
import os
import argparse
import asyncio
import pprint
from hgdb import HGDBClient, DebugSymbolTable
from prompt_toolkit import PromptSession
from prompt_toolkit.patch_stdout import patch_stdout
import pygments
import pygments.formatters
import pygments.lexers
from prompt_toolkit.auto_suggest import AutoSuggest, Suggestion
from prompt_toolkit.key_binding import KeyBindings


def get_arguments():
    """Parse the debugger's command line.

    Returns:
        A ``(symbol_table, port, no_db_connection)`` tuple taken from the
        ``--db``/``-i``, ``--port``/``-p`` and ``--no-db-connection`` options.
    """
    # (flags, keyword settings) for every supported option
    option_table = (
        (("--db", "-i"),
         dict(dest="symbol_table", type=str, required=True, help="Filename to the symbol table")),
        (("--port", "-p"),
         dict(dest="port", type=int, help="Port number for the simulator", default=8888)),
        (("--no-db-connection",),
         dict(dest="no_db_connection", action="store_true",
              help="If present, will not inform the simulator about the debug symbol table", default=False)),
    )
    cli = argparse.ArgumentParser("hgdb debugger")
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    parsed = cli.parse_args()
    return parsed.symbol_table, parsed.port, parsed.no_db_connection


def index_filenames(filenames):
    """Build lookup tables between full file names and unambiguous basenames.

    Args:
        filenames: iterable of file names (typically full paths).

    Returns:
        A ``(result, shorten)`` tuple:

        * ``result`` maps every accepted spelling to the full file name. Each
          full name maps to itself; a basename is added as a shorthand only
          when exactly one file owns it and it is not itself a full name.
        * ``shorten`` maps a full file name to its shorthand, for the files
          that have one.
    """
    filenames = list(filenames)
    # exact names always win over shorthands, so a file given without a
    # directory component can never be evicted by the conflict detection
    full_names = set(filenames)
    result = {}
    shorthand = {}
    conflicted = set()
    for filename in filenames:
        result[filename] = filename
        basename = os.path.basename(filename)
        if not basename or basename in full_names or basename in conflicted:
            continue
        owner = shorthand.get(basename)
        if owner is None:
            shorthand[basename] = filename
            result[basename] = filename
        elif owner != filename:
            # two different files share this basename: it is ambiguous
            del shorthand[basename]
            del result[basename]
            conflicted.add(basename)
    shorten = {full: short for short, full in shorthand.items()}
    return result, shorten


async def get_client(filename, port, skip_connection=False):
    """Create an ``HGDBClient`` for the local simulator and connect it.

    Args:
        filename: symbol table file name handed to the client, unless
            ``skip_connection`` is set.
        port: port number on localhost the websocket URI points to.
        skip_connection: when true, ``None`` is passed to the client instead
            of ``filename``.

    Returns:
        The connected client, after single thread mode has been requested and
        the reported status has been printed.
    """
    if skip_connection:
        symbol_table = None
    else:
        symbol_table = filename
    debugger = HGDBClient("ws://localhost:{0}".format(port), symbol_table)
    await debugger.connect()
    # the REPL only ever looks at one instance per breakpoint hit
    await debugger.change_option(single_thread_mode=True)
    status = await debugger.get_info("status")
    print(status["payload"]["status"], end="")
    return debugger


def read_line(filename, line_num, file_context_cache, line_range=0):
    """Return the source lines around ``line_num`` of ``filename``.

    Args:
        filename: path of the file to read.
        line_num: one-based line number at the centre of the window.
        file_context_cache: dict mapping file name to the list of its lines;
            it is filled on the first read so later calls skip the disk read.
        line_range: number of extra lines to include on each side.

    Returns:
        ``None`` when the file does not exist, otherwise a list of
        ``(index, text)`` tuples where ``index`` is the ZERO-based line index
        (callers add 1 for display) and ``text`` has trailing whitespace
        stripped. The list is empty when ``line_num`` is past the end.
    """
    if not os.path.exists(filename):
        return None
    lines = file_context_cache.get(filename)
    if lines is None:
        with open(filename) as f:
            lines = f.readlines()
        # remember the content; the original never stored it, so the cache
        # was never hit
        file_context_cache[filename] = lines
    first_line = max(1, line_num - line_range)
    last_line = min(len(lines), line_num + line_range)
    return [(i, lines[i].rstrip()) for i in range(first_line - 1, last_line)]


def print_line(filename, line_num, line, formatter):
    """Print one syntax-highlighted source line prefixed by its line number.

    Args:
        filename: file the line comes from; only used to pick a lexer.
        line_num: zero-based line index; it is printed as ``line_num + 1``.
        line: text of the line.
        formatter: pygments formatter passed to ``pygments.highlight``.
    """
    from pygments.util import ClassNotFound

    try:
        lexer = pygments.lexers.get_lexer_for_filename(filename)
    except ClassNotFound:
        # no lexer is registered for this file name: show the line without
        # highlighting instead of aborting the debugger
        lexer = pygments.lexers.get_lexer_by_name("text")
    print(line_num + 1, pygments.highlight(line, lexer, formatter), end="")


def get_fn_ln_cn(filename, line_num, column_num, shorten_filename_map):
    """Format a source location as ``file:line`` or ``file:line:column``.

    The file name is replaced by its entry in ``shorten_filename_map`` when it
    has one. The column is left out when it is 0, and an empty ``filename``
    yields an empty string.
    """
    if len(filename) == 0:
        return ""
    display_name = shorten_filename_map.get(filename, filename)
    if column_num != 0:
        return "{0}:{1}:{2}".format(display_name, line_num, column_num)
    return "{0}:{1}".format(display_name, line_num)


class AutoSuggestFromHistoryPath(AutoSuggest):
    """Auto-suggest that completes from history first, then from file names."""

    def __init__(self, path):
        # iterable of file names offered after a list/breakpoint command
        self.path = path

    def get_suggestion(self, buffer, document):
        """Return a ``Suggestion`` for the last line being typed, or ``None``.

        The most recent history line starting with the typed text wins. If
        there is none and the text is ``list``/``l``/``b``/``breakpoint``
        followed by a partial path, the first entry of ``self.path`` that
        starts with that partial path is suggested.
        """
        past_entries = buffer.history
        typed = document.text.rsplit("\n", 1)[-1]
        if not typed.strip():
            return None

        # newest history entry first, last line of each entry first
        for entry in reversed(list(past_entries.get_strings())):
            for past_line in reversed(entry.splitlines()):
                if past_line.startswith(typed):
                    return Suggestion(past_line[len(typed):])

        words = typed.split(" ")
        if len(words) <= 1 or words[0] not in {"list", "l", "b", "breakpoint"}:
            return None
        # the typed text starts with "<command> ", the rest is the partial path
        partial = typed[len(words[0]) + 1:]
        if not partial:
            return None
        for candidate in self.path:
            if candidate.startswith(partial):
                return Suggestion(candidate[len(partial):])
        return None


def fix_local_str_num(target):
    """Recursively convert decimal strings to ints and index-keyed dicts to lists.

    Args:
        target: a leaf value or a (possibly nested) dict.

    Returns:
        * for a non-dict: ``int(target)`` when it is a string made only of
          decimal digits, otherwise ``target`` unchanged;
        * for a non-empty dict whose keys are exactly ``"0" .. "n-1"``: a list
          of the converted values in index order;
        * for any other dict (including an empty one): the same dict object,
          with its values converted in place.
    """
    if not isinstance(target, dict):
        # isdecimal() (unlike isdigit()) only accepts what int() can parse,
        # e.g. it rejects superscript digits
        if isinstance(target, str) and target.isdecimal():
            return int(target)
        return target

    size = len(target)
    # an empty dict carries no indices, so it must stay a dict
    is_array = size > 0 and all(str(i) in target for i in range(size))
    if not is_array:
        for key, value in target.items():
            target[key] = fix_local_str_num(value)
        return target

    result = [None] * size
    for key, value in target.items():
        result[int(key)] = fix_local_str_num(value)
    return result


def parse_local(local_vars):
    """Turn the flat local-variable mapping into a nested structure.

    Names such as ``a.b[1]`` are split on ``.`` and ``[``/``]`` and placed in
    nested dicts; the result is then passed through ``fix_local_str_num``,
    which converts numeric strings to ints and index-keyed dicts to lists.

    Args:
        local_vars: dict mapping flat variable names to their values.

    Returns:
        The nested structure produced by ``fix_local_str_num``.
    """
    result = {}
    # notice that everything is flat from the debug server
    for name, value in local_vars.items():
        # replace [] with .
        tokens = name.replace("[", ".").replace("]", "").split(".")
        # build up the hierarchy
        target = result
        for token in tokens[:-1]:
            child = target.get(token)
            if not isinstance(child, dict):
                # missing, or previously stored as a scalar: the later, more
                # specific name wins instead of raising on item assignment
                child = {}
                target[token] = child
            target = child
        # leaves are stored untouched; the numeric conversion is done once,
        # for every leaf, by the second pass below
        target[tokens[-1]] = value

    # second pass recursively change the map into array
    return fix_local_str_num(result)


def locate_local_vars(expr, local_vars):
    """Look up a variable path such as ``a.b[2]`` in the parsed local variables.

    Args:
        expr: the expression typed by the user.
        local_vars: nested dict/list structure of local variables.

    Returns:
        The value found at the path, or ``None`` when ``expr`` contains
        whitespace (it is then not a plain variable path) or the path does not
        exist.
    """
    if len(expr.split()) > 1:
        return None
    name = expr.replace("[", ".").replace("]", "")
    target = local_vars
    for var_name in name.split("."):
        is_index = var_name.isdecimal()
        if isinstance(target, dict):
            if var_name in target:
                target = target[var_name]
            elif is_index and int(var_name) in target:
                target = target[int(var_name)]
            else:
                return None
        elif isinstance(target, list):
            # lists are addressed by position; `in` would test the elements
            if not is_index or int(var_name) >= len(target):
                return None
            target = target[int(var_name)]
        else:
            # trying to descend into a scalar value
            return None
    return target


async def main_loop(client: HGDBClient, db, session, formatter):
    """Run the interactive ``(hgdb)`` command prompt.

    Reads one command per line and dispatches it to the matching handler.
    Handlers that resume the simulation (continue / step / stepback) return
    ``True``; the loop then waits for the next breakpoint event from
    ``client`` and renders it.

    Args:
        client: connected debugger client used for every simulator request.
        db: symbol table; only ``db.get_filenames()`` is used here.
        session: prompt_toolkit session used when stdin is a TTY; otherwise
            plain ``input()`` is used and ``session`` is not touched.
        formatter: pygments formatter used to highlight source lines.

    Returns:
        ``None``, once ``client.recv()`` yields ``None`` (reported as
        "Simulator exited"). ``EOFError`` raised by the prompt is not caught
        here and propagates to the caller.
    """
    filename_map, shortened_map = index_filenames(db.get_filenames())
    # scope passed to evaluate(); set to the id of the last breakpoint hit
    current_scope = ""
    file_context_cache = {}
    # location of the last breakpoint hit, shown in the right prompt
    current_breakpoint_fn = ""
    current_breakpoint_ln = 0
    current_breakpoint_cn = 0
    # for local vars
    local_vars = {}

    def check_error(resp):
        # print the reason of an error response; True means "no error"
        if resp["status"] == "error":
            print(resp["payload"]["reason"])
            return False
        return True

    async def print_var(expr):
        # local variables of the current breakpoint are answered locally,
        # everything else is evaluated by the simulator
        target = locate_local_vars(expr, local_vars)
        if target is not None:
            pprint.pprint(target)
            return
        resp = await client.evaluate(current_scope, expr, False)
        if check_error(resp):
            print(resp["payload"]["result"])

    def parse_fn_ln(expr):
        # parse "file:line[:column]" into (full filename, line, column);
        # returns None when the text is malformed or the file is unknown
        tokens = expr.split(":")
        if len(tokens) < 2 or len(tokens) > 3:
            return None
        filename = tokens[0]
        if filename not in filename_map:
            return None
        filename = filename_map[filename]
        if not tokens[1].isdigit():
            return None
        line_num = int(tokens[1])
        column_num = 0
        if len(tokens) == 3:
            if not tokens[2].isdigit():
                return None
            column_num = int(tokens[2])
        return filename, line_num, column_num

    async def insert_breakpoint(expr):
        location = parse_fn_ln(expr)
        if location is None:
            print("Invalid breakpoint", expr)
            return
        filename, line_num, column_num = location
        resp = await client.set_breakpoint(filename, line_num, column_num, check_error=False)
        check_error(resp)

    async def remove_breakpoint(expr):
        if not expr.isdigit():
            print("Unable to parse breakpoint id", expr)
            # without this return int() below raises on the invalid text
            return
        bp_id = int(expr)
        resp = await client.remove_breakpoint_id(bp_id, False)
        check_error(resp)

    async def clear_breakpoint(expr):
        location = parse_fn_ln(expr)
        if location is None:
            print("Unable to parse breakpoint", expr)
            return
        fn, ln, cn = location
        await client.remove_breakpoint(fn, ln, cn)

    async def continue_(_):
        await client.continue_()
        return True

    async def step_over(_):
        await client.step_over()
        return True

    async def step_back(_):
        await client.step_back()
        return True

    async def list_file(expr):
        location = parse_fn_ln(expr)
        if location is None:
            print("Invalid filename at", expr)
            return
        filename, line_num, _ = location
        # need to read out the filename given these lines
        lines = read_line(filename, line_num, file_context_cache, 5)
        if lines is None:
            # the symbol table knows the file but it is not on this machine
            print("Unable to read", filename)
            return
        for line_index, line in lines:
            print_line(filename, line_index, line, formatter)

    async def info_(expr):
        tokens = expr.split(" ")
        if len(tokens) != 1:
            print("Invalid information command", expr)
            return
        cmd = tokens[0]
        if cmd == "b" or cmd == "breakpoint":
            # need to get the current inserted breakpoints
            resp = await client.get_info("breakpoints")
            for bp in resp["payload"]["breakpoints"]:
                print("{0:8} {1}".format(bp["id"], get_fn_ln_cn(bp["filename"], bp["line_num"], bp["column_num"],
                                                                shortened_map)))
        else:
            print("Invalid info command", expr)

    async def exit_(_):
        print("exit")
        await client.close()
        # sys.exit instead of the site-provided exit() helper
        sys.exit(0)

    async def help_(expr):
        if len(expr) == 0:
            # print out the available commands
            commands_str = " ".join(commands)
            print("List of available commands:\033[1m", commands_str, '\033[0m')
            print('Type "help" followed by a command to see detailed explanation')
            return
        if expr not in commands:
            print("Invalid command", expr, "for help")
            return
        help_str = commands_help[commands[expr]]
        print('\033[1m' + expr + '\033[0m', "--", help_str)

    def parse_breakpoint_info(resp):
        payload = resp["payload"]
        filename = payload["filename"]
        line_num = payload["line_num"]
        column_num = payload["column_num"]
        # we are in the single thread mode
        instance = payload["instances"][0]
        instance_id = instance["instance_id"]
        breakpoint_id = instance["breakpoint_id"]
        local_vars_ = parse_local(instance["local"])
        return (filename, line_num, column_num), instance_id, breakpoint_id, local_vars_

    def render_breakpoint(bp_info, breakpoint_id):
        (filename, line_num, column_num) = bp_info
        fn_ln_info = get_fn_ln_cn(filename, line_num, column_num, shortened_map)
        print("Breakpoint", breakpoint_id, "at", fn_ln_info)
        line_text = read_line(filename, line_num, file_context_cache)
        if line_text is not None and len(line_text) > 0:
            line_index, line = line_text[0]
            print_line(filename, line_index, line, formatter)

    # registers commands; "breakpoint" is accepted because the auto-suggest
    # already treats it as a command taking a file path
    commands = {"c": continue_, "continue": continue_, "s": step_over, "step": step_over, "b": insert_breakpoint,
                "breakpoint": insert_breakpoint,
                "p": print_var, "l": list_file, "list": list_file, "i": info_, "info": info_, "d": remove_breakpoint,
                "delete": remove_breakpoint, "clear": clear_breakpoint, "q": exit_, "exit": exit_, "quit": exit_,
                "stepback": step_back, "sb": step_back, "help": help_}

    commands_help = {continue_: "Continues running the simulation until the next breakpoint",
                     step_over: "Continue running the simulation until it reaches different source line",
                     step_back: "Step backward in the simulation",
                     insert_breakpoint: "Puts a breakpoint at particular line",
                     print_var: "Prints the current value of variable or expressions",
                     list_file: "Prints lines from a source file",
                     info_: "Shows simulator information",
                     remove_breakpoint: "Removes breakpoint given the ID",
                     clear_breakpoint: "Clears breakpoints at particular file",
                     exit_: "Detach from the simulator",
                     help_: "Shows help menu"}

    # for tab auto-complete: accept the current suggestion
    bindings = KeyBindings()

    @bindings.add('tab')
    def _(event):
        buffer = event.current_buffer
        suggestion: Suggestion = buffer.suggestion
        if suggestion is not None:
            buffer.text = buffer.text + suggestion.text
            buffer.cursor_position = len(buffer.text)

    auto_suggest = AutoSuggestFromHistoryPath(filename_map)

    while True:
        try:
            if sys.stdin.isatty():
                with patch_stdout():
                    result = await session.prompt_async("(hgdb) ",
                                                        auto_suggest=auto_suggest,
                                                        rprompt=get_fn_ln_cn(current_breakpoint_fn,
                                                                             current_breakpoint_ln,
                                                                             current_breakpoint_cn, shortened_map),
                                                        key_bindings=bindings)
            else:
                result = input("(hgdb) ")
        except KeyboardInterrupt:
            # Ctrl-C only discards the current input line
            continue
        result = result.strip()
        if not result:
            # nothing typed: just show the prompt again
            continue
        command_tokens = result.split(" ")
        command = command_tokens[0]
        if command not in commands:
            print("Invalid command", result)
            continue
        command_arg = " ".join(command_tokens[1:])
        should_recv = await commands[command](command_arg.strip())
        if not should_recv:
            continue

        # the simulation was resumed: wait until it stops at a breakpoint
        while True:
            event = await client.recv()
            if event is None:
                print("Simulator exited")
                return
            if event["type"] == "breakpoint":
                # don't care about other stuff since it's handled properly via recv()
                break
        line_info, _instance_id, b_id, local_vars = parse_breakpoint_info(event)
        current_scope = str(b_id)
        render_breakpoint(line_info, b_id)
        current_breakpoint_fn, current_breakpoint_ln, current_breakpoint_cn = line_info


def main():
    """Entry point: parse the arguments, connect to the simulator, run the prompt.

    Exits with status 1 when the symbol table file does not exist. A
    ``KeyboardInterrupt`` or ``EOFError`` escaping the prompt loop ends the
    session with an "exit" message.
    """
    filename, port, no_db_connection = get_arguments()
    if not os.path.exists(filename):
        print("Unable to find", filename, file=sys.stderr)
        sys.exit(1)

    # load the table locally as well
    db = DebugSymbolTable(filename)

    # One explicit loop for both phases: the client has to be driven by the
    # loop it connected on, and asyncio.get_event_loop() without a current
    # loop is deprecated (and an error on recent Python versions).
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        client: HGDBClient = loop.run_until_complete(get_client(filename, port, no_db_connection))

        # prompt session
        session = PromptSession() if sys.stdin.isatty() else None

        # formatter
        formatter = pygments.formatters.get_formatter_by_name("terminal")

        # loop until finish
        try:
            loop.run_until_complete(main_loop(client, db, session, formatter))
        except (KeyboardInterrupt, EOFError):
            print("exit")
    finally:
        loop.close()

# start the debugger only when this file is executed as a script
if __name__ == "__main__":
    main()
