#!/usr/bin/env python
import os
import sys
import yaml
import ast
import argparse
import pprint

import libem
from libem.match.parameter import tools


def add_common_arguments(parser):
    """Register the options shared by several Libem subcommands on ``parser``.

    Three groups are added, in this order: the free-form ``--calibrate``
    KEY=VALUE pairs, shortcut options for common calibrations (model,
    temperature, browse, cot, confidence, guess), and output options
    (debug, stats).

    Args:
        parser: the ``argparse`` parser (or subparser) to extend in place.
    """
    # Free-form calibration: one or more KEY=VALUE tokens, each converted to
    # a (key, value) tuple by parse_key_value_pair.
    parser.add_argument(
        '-c', '--calibrate',
        metavar="KEY=VALUE",
        type=parse_key_value_pair,
        nargs='+',
        help="Key-value pairs separated by '=', e.g., key1=value1 key2=value2",
    )

    # Valued shortcuts; their defaults are read from libem's current parameters.
    parser.add_argument(
        "-m", "--model",
        dest='model',
        nargs='?',
        type=str,
        default=libem.parameter.model(),
        help="The LLM to use for model call.",
    )
    parser.add_argument(
        "--temperature",
        dest='temperature',
        nargs='?',
        type=float,
        default=libem.parameter.temperature(),
        help="The temperature to use for model call.",
    )

    # Boolean switches (calibration shortcuts first, then output options).
    # Each one is off unless given; the dest is the long flag without dashes.
    switches = (
        (("-b", "--browse"), "Enable the browse tool."),
        (("--cot",), "Enable chain of thought."),
        (("--confidence",), "Report confidence score."),
        (("-g", "--guess"), "Match by guessing."),
        (("-d", "--debug"), "Enable debug mode."),
        (("-s", "--stats"), "Print stats from trace."),
    )
    for flags, description in switches:
        parser.add_argument(
            *flags,
            dest=flags[-1].lstrip('-'),
            help=description,
            action='store_true',
            default=False,
        )


def chat(args):
    """Run an interactive Libem chat session on stdin/stdout.

    Each line typed at the ``libem > `` prompt is sent to ``libem.chat``
    together with the running context; the reply's ``"content"`` is printed
    and its ``"context"`` is carried into the next turn. The loop ends when
    the user types ``quit`` (in any letter case) or when stdin is closed.

    Args:
        args: parsed command-line namespace; currently unused.
    """
    _ = args
    print("Welcome to the Libem Chatbot.")
    print("Type 'quit' to quit the chatbot.")

    context = []
    while True:
        try:
            user_input = input("libem > ").strip()
        except EOFError:
            # stdin was closed (Ctrl-D or an exhausted pipe): leave the same
            # way 'quit' does instead of surfacing a traceback.
            print()
            sys.exit()

        # Only the quit check is case-insensitive; the message itself is
        # forwarded with the user's original casing.
        if user_input.lower() == 'quit':
            sys.exit()

        response = libem.chat(user_input, context)

        print("libem >", response["content"])

        context = response["context"]


def configure(args):
    """Interactively update the Libem configuration file.

    Reads ``~/.libem/config.yaml`` (creating the directory if needed),
    prompts for ``OPENAI_API_KEY`` while showing the current key masked,
    and writes the configuration back. Pressing Enter at the prompt keeps
    the stored key.

    Args:
        args: parsed command-line namespace; currently unused.

    Raises:
        SystemExit: if the existing file's top level is not a YAML mapping,
            so it cannot be updated key by key.
    """
    _ = args

    config_path = os.path.expanduser('~/.libem/config.yaml')
    os.makedirs(os.path.dirname(config_path), exist_ok=True)

    # Load the existing configuration if it exists
    if os.path.exists(config_path):
        with open(config_path, 'r') as file:
            config = yaml.safe_load(file) or {}
    else:
        config = {}

    # A list or scalar at the top level has no .get(); stop with a clear
    # message rather than crashing or silently overwriting the file.
    if not isinstance(config, dict):
        sys.exit(f"Cannot update {config_path}: "
                 f"expected a YAML mapping at the top level.")

    # Prompt for OPENAI_API_KEY. The stored value may be null or a non-string
    # scalar in the YAML; normalise it so masking cannot fail on len().
    existing_key = str(config.get('OPENAI_API_KEY') or '')
    new_key = input(f"Enter OPENAI_API_KEY (press Enter to keep the existing key: "
                    f"'{mask_key(existing_key)}'): ").strip()

    # If no input, keep the existing key; otherwise, update
    if new_key:
        config['OPENAI_API_KEY'] = new_key
    else:
        print("No input provided. Using existing API key.")

    ...

    # Save the updated configuration
    with open(config_path, 'w') as file:
        yaml.safe_dump(config, file)

    # The file holds an API key: restrict it to the owning user.
    os.chmod(config_path, 0o600)


def entry():
    """Parse the command line and run the selected Libem subcommand.

    Subcommands: ``chat``, ``configure``, ``match`` (two entities) and
    ``extract`` (one description). When no subcommand is given, the
    top-level help is printed instead.
    """
    cli = argparse.ArgumentParser(description="Libem CLI tool")
    commands = cli.add_subparsers(dest='command', help='Available commands:')

    # chat interface
    add_common_arguments(commands.add_parser('chat', help='Start a Libem chatbot'))

    # libem configuration (no options of its own)
    commands.add_parser('configure', help='Set Libem configurations')

    # --- Libem tools ---
    # match: two positional entities plus the shared options
    matcher = commands.add_parser('match', help='Perform entity matching')
    for positional, description in (('e1', 'First entity'),
                                    ('e2', 'Second entity')):
        matcher.add_argument(positional, type=str, help=description)
    add_common_arguments(matcher)

    # extract: one positional description plus the shared options
    extractor = commands.add_parser('extract', help='Perform entity extraction')
    extractor.add_argument('desc', type=str, help='Description')
    add_common_arguments(extractor)

    ...

    args = cli.parse_args()

    # Map each subcommand name to its handler; a missing subcommand
    # (args.command is None) falls through to the help text.
    handlers = {
        'chat': chat,
        'configure': configure,
        'match': match,
        'extract': extract,
    }
    handler = handlers.get(args.command)
    if handler is None:
        cli.print_help()
    else:
        handler(args)


def match(args):
    """Run entity matching for the ``match`` subcommand and print the result.

    Collects calibration settings from ``--calibrate`` and from the shortcut
    options, applies them with ``libem.calibrate``, matches ``args.e1``
    against ``args.e2`` under ``libem.trace``, then prints the answer.
    The explanation is printed with ``--cot``, the confidence with
    ``--confidence``, and the trace stats with ``--stats``.

    Args:
        args: parsed command-line namespace of the ``match`` subcommand.
    """
    # Copy the user's pairs: the shortcut entries appended below must not
    # mutate args.calibrate in place.
    configs = list(args.calibrate or [])

    # Shortcuts are appended after the --calibrate pairs, so for a repeated
    # key the shortcut value wins when the list is turned into a dict.
    if args.model:
        configs.append(('libem.match.parameter.model', args.model))
    # NOTE(review): this truthiness test never forwards a temperature of 0.
    # Switching to `is not None` would also forward the parser default over
    # any --calibrate value; confirm the intended precedence before changing.
    if args.temperature:
        configs.append(('libem.match.parameter.temperature', args.temperature))
    if args.cot:
        configs.append(('libem.match.parameter.cot', True))
    if args.confidence:
        configs.append(('libem.match.parameter.confidence', True))
    if args.browse:
        # Build a new list so the imported `tools` value is left untouched.
        configs.append(('libem.match.parameter.tools', tools + ['libem.browse']))
    if args.guess:
        configs.append(('libem.match.parameter.guess', True))
    if args.debug:
        libem.debug_on()

    if configs:
        libem.calibrate(dict(configs))

    with libem.trace as t:
        result = libem.match(args.e1, args.e2)

    if args.cot:
        print("Explanation:\n", result['explanation'])
        print()

    print("Match:", result['answer'])

    if args.confidence:
        print("Confidence:", result['confidence'])

    if args.stats:
        # sort_dicts=False keeps the stats in the order the trace reports them.
        pp = pprint.PrettyPrinter(sort_dicts=False)
        pp.pprint(t.stats())


def extract(args):
    """Run entity extraction for the ``extract`` subcommand and print the result.

    Extracts entities from ``args.desc`` under ``libem.trace`` and prints
    them. The ``debug`` and ``stats`` options are honoured the same way the
    ``match`` handler honours them; the remaining shared options are not
    applied here.

    Args:
        args: parsed command-line namespace; must provide ``desc``.
            ``debug`` and ``stats`` are optional and default to off.
    """
    # getattr keeps callers working that pass a namespace without these flags.
    if getattr(args, 'debug', False):
        libem.debug_on()

    with libem.trace as t:
        result = libem.extract(args.desc)

    print("Entities:", result)

    if getattr(args, 'stats', False):
        pp = pprint.PrettyPrinter(sort_dicts=False)
        pp.pprint(t.stats())


def parse_key_value_pair(arg_value):
    """Parse one ``KEY=VALUE`` command-line token into a ``(key, value)`` tuple.

    The token is split at the first '=' only, so the value may itself
    contain '='. The value is evaluated as a Python literal when possible
    (numbers, booleans, None, lists, dicts, quoted strings, ...); anything
    that is not a valid literal is returned unchanged as a string.

    Args:
        arg_value: the raw token from the command line.

    Returns:
        A ``(key, value)`` tuple.

    Raises:
        argparse.ArgumentTypeError: if the token contains no '='.
    """
    key, separator, value = arg_value.partition('=')
    if not separator:
        raise argparse.ArgumentTypeError("Key-value pairs must be separated by '='")
    try:
        # Safely evaluate the value part to handle data structures like lists
        value = ast.literal_eval(value)
    except (SyntaxError, ValueError, TypeError):
        # Not a usable literal: keep the raw string. TypeError covers text
        # that parses but cannot be built, e.g. a dict with a list as key.
        pass
    return key, value


def mask_key(key):
    """Return ``key`` with its middle replaced by asterisks, for display.

    The first 3 and last 4 characters stay visible and every character in
    between becomes '*'. Keys of 7 characters or fewer are returned
    unchanged, since there is nothing left to hide between the two ends.
    """
    visible_head, visible_tail = 3, 4
    hidden_count = len(key) - (visible_head + visible_tail)
    if hidden_count <= 0:
        # Too short to mask: hand it back as is.
        return key
    stars = '*' * hidden_count
    return key[:visible_head] + stars + key[-visible_tail:]


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    entry()
