#!python

import io
import re
from os import environ
from os.path import join, isfile, realpath, expanduser
from gnureadline import parse_and_bind, set_completer
from argparse import ArgumentParser
from contextlib import contextmanager
from signal import signal, SIGINT, SIGALRM, setitimer, ITIMER_REAL
from textwrap import dedent
from glob import glob

from lark import Lark
from lark.visitors import Interpreter
from gpt4all import GPT4All

from gpt4all_cli import *

# Revision identifier of the running code. It is taken from the first source
# that works, in this order:
#   1. the GPT4ALLCLI_REVISION environment variable;
#   2. the output of `git rev-parse HEAD` run in the directory named by the
#      GPT4ALLCLI_ROOT environment variable;
#   3. the REVISION constant of the `gpt4all_cli.revision` module.
# When none of them is available REVISION is None.
REVISION:str|None
try:
  REVISION=environ["GPT4ALLCLI_REVISION"]
except Exception:
  try:
    from subprocess import check_output, DEVNULL
    # Any failure here (GPT4ALLCLI_ROOT not set, git failing, ...) falls
    # through to the next source; git's stderr output is discarded.
    REVISION=check_output(['git', 'rev-parse', 'HEAD'],
                          cwd=environ['GPT4ALLCLI_ROOT'],
                          stderr=DEVNULL).decode().strip()
  except Exception:
    try:
      from gpt4all_cli.revision import REVISION as __rv__
      REVISION = __rv__
    except ImportError:
      REVISION = None

def _get_cmd_prefix():
  """ Return the first character shared by all non-empty entries of COMMANDS.

  The assertion fails when the commands do not all start with the same
  character. """
  leading_chars = set()
  for name in COMMANDS:
    if len(name) > 0:
      leading_chars.add(name[0])
  assert len(leading_chars) == 1
  return list(leading_chars)[0]

# Leading character common to all commands, computed once at import time.
CMDPREFIX = _get_cmd_prefix()

def read_multiline_input(initial_prompt=">>> ", intermediate_prompt="... "):
  """ Read lines from standard input until a terminating line or end of input.

  The first line is requested with `initial_prompt`, every following one with
  `intermediate_prompt`. Reading stops after an empty line or a line whose
  first character is in CMDPREFIX (that line is included in the result), or
  when input() raises EOFError.

  Returns the list of lines read, in order. """
  collected = []
  next_prompt = initial_prompt
  while True:
    try:
      entry = input(next_prompt)
    except EOFError:
      return collected
    collected.append(entry)
    # An empty line or a line starting with the command prefix ends the input.
    if not entry or entry[0] in CMDPREFIX:
      return collected
    next_prompt = intermediate_prompt

# Command-line interface of the program. Every option is optional.
ARG_PARSER = ArgumentParser(description="Command-line arguments")

# Extra directory in which model file names are looked up.
ARG_PARSER.add_argument("--model-dir", default=None, type=str,
                        help="Model directory to prepend to model file names")

# There is no default model: None means that no model was named.
ARG_PARSER.add_argument("--model", "-m", default=None, type=str,
                        help="Model to use for chatbot")

# None means that no thread count was requested.
ARG_PARSER.add_argument("--num-threads", "-t", default=None, type=int,
                        help="Number of threads to use for chatbot")

ARG_PARSER.add_argument("--device", "-d", default=None, type=str,
                        help="Device to use for chatbot, e.g. gpu, amd, nvidia, intel. Defaults to CPU.")

# Readline key sequence; its backslash is doubled so that argparse shows it verbatim.
ARG_PARSER.add_argument("--readline-key-send", default="\\C-k", type=str,
                        help="Terminal code to treat as Ctrl+Enter (default: \\C-k)")

ARG_PARSER.add_argument('--readline-prompt', default=">>> ", type=str,
                        help="Input prompt (default: >>>)")

# Flag without a value.
ARG_PARSER.add_argument('--revision', action='store_true',
                        help="Print the revision")



@contextmanager
def with_sigint(_handler):
  """ Install `_handler` as the SIGINT handler for the duration of the `with` body and
  put the previously installed handler back afterwards, also when the body raises.

  Not fully correct: signals should additionally be masked while the handlers are
  being switched. """
  previous_handler = signal(SIGINT, _handler)
  try:
    yield
  finally:
    signal(SIGINT, previous_handler)


def ask1(gpt4all, message:str) -> str|None:
  """ Send `message` to the model and stream its answer to standard output.

  Tokens are printed as soon as they arrive. While the answer is streaming,
  Ctrl+C (SIGINT) does not raise KeyboardInterrupt: it sets a flag that makes
  the generation callback return False (which, per the gpt4all API, is
  expected to stop the generation -- confirm against the installed version).
  A newline is always printed at the end, also when generation fails.

  Args:
    gpt4all: object providing `generate(message, ..., streaming=True, callback=...)`
      that returns an iterable of string tokens.
    message: the prompt text.

  Returns:
    The text of the answer received so far (possibly partial after an
    interrupt). Earlier versions returned an already closed StringIO object,
    from which the text could not be read any more.
  """
  tokens = []
  break_request = False

  def _signal_handler(signum,frame):
    nonlocal break_request
    print("\n<Keyboard interrupt>")
    break_request = True

  def _model_callback(*args, **kwargs):
    # Returning False asks the model to stop generating.
    return not break_request

  try:
    response_generator = gpt4all.generate(
      message,
      # preferential kwargs for chat ux
      max_tokens=200,
      temp=0.9,
      top_k=40,
      top_p=0.9,
      min_p=0.0,
      repeat_penalty=1.1,
      repeat_last_n=64,
      n_batch=9,
      # required kwargs for cli ux (incremental response)
      streaming=True,
      callback=_model_callback
    )

    # The custom SIGINT handler is active only while tokens are consumed.
    with with_sigint(_signal_handler):
      for token in response_generator:
        print(token, end='', flush=True)
        tokens.append(token)
  finally:
    # Terminate the streamed line even if generation failed.
    print()
  return ''.join(tokens)

# Parser for the text typed by the user, built from GRAMMAR with `start` as the entry rule.
PARSER = Lark(GRAMMAR, start='start')


def test_parser():
  """ Self-check of PARSER: parse a set of sample inputs and compare the pretty-printed
  trees with the expected ones.

  Before the comparison both sides are stripped, tabs are replaced by seven spaces and
  whitespace-only lines are emptied. Raises AssertionError on the first mismatch. """
  blank_line = re.compile(r"^[ \t]+$", flags=re.MULTILINE)

  def _normalize(rendered):
    return blank_line.sub('', rendered.strip().replace('\t', ' '*7))

  # Pairs of (input text, expected tree); the common indentation of the expected
  # trees is removed by dedent().
  cases = [
    (r'\a', r'''
      start
        escape       \a
    '''),
    ('/echo ', r'''
      start
        command       /echo
        text
    '''),
    ('/echo', r'''
      start
        command       /echo
    '''),
    ('a/echo', r'''
      start
        text       a
        command       /echo
    '''),
    ('/echoa', r'''
      start
        command       /echo
        text       a
    '''),
    (r'\/echo', r'''
      start
        escape       \/
        text       echo
    '''),
    ('/echo/echo', r'''
      start
        command       /echo
        command       /echo
    '''),
    ('/echo/echoxx', r'''
      start
        command       /echo
        command       /echo
        text       xx
    '''),
    ('/model "aaa"', r'''
      start
        command
          /model

          string       "aaa"
    '''),
    (r'/echo\a', r'''
      start
        command       /echo
        escape       \a
    '''),
    ('', r'''
      start
    '''),
    ('/nthreads 3', r'''
      start
        command
          /nthreads

          number       3
    '''),
  ]

  for source, expected in cases:
    got = _normalize(PARSER.parse(source).pretty())
    want = _normalize(dedent(expected))
    assert got == want, f"\nExpected:\n{want}\nGot:\n{got}"

def print_help():
  """ Print the names of all commands on one line, separated by spaces. """
  command_list = ' '.join(COMMANDS)
  print("Commands: " + command_list)

@contextmanager
def chat_session(gpt4all):
  """ Run the `with` body inside `gpt4all.chat_session()`, or without any session
  when `gpt4all` is None. """
  if gpt4all is not None:
    with gpt4all.chat_session():
      yield
    return
  # No model: there is nothing to open, just run the body.
  yield

def model_locations(args, model)->list[str]:
  """ List the existing paths that `model` may refer to.

  `model` is tried as given and, when `args.model_dir` is set, also joined to
  that directory. Each candidate gets `~` expanded and is then treated as a
  glob pattern; every match is reported as a resolved real path, with the
  matches of `model` itself coming first.

  Returns a list, as the annotation declares (earlier versions were a
  generator despite the annotation); it is empty when nothing matches. """
  candidates = [model]
  if args.model_dir:
    candidates.append(join(args.model_dir, model))
  return [realpath(match)
          for candidate in candidates
          for match in glob(expanduser(candidate))]

def main(cmdline=None):
  """ Run the interactive chat program.

  Parses the command line, checks the grammar with test_parser(), configures
  readline and then reads user input in a loop. Every input line is parsed with
  PARSER and interpreted: plain text is collected into the pending message (or
  printed while an echo command is active) and commands are executed.

  Errors reported as ValueError or RuntimeError, and input that the grammar
  rejects (LarkError), are printed and the loop goes on; in that case the chat
  session context is left and entered again. End of input ends the program.

  Args:
    cmdline: list of arguments handed to ArgumentParser.parse_args(); None
      makes argparse use the process arguments.

  Returns:
    0 after printing the revision (--revision), None otherwise.
  """
  # Lark is already a dependency of this file; the base class of its errors is
  # only needed by the input loop at the end of this function.
  from lark.exceptions import LarkError

  args = ARG_PARSER.parse_args(cmdline)
  if args.revision:
    print(REVISION)
    return 0

  # Grammar self-check; an AssertionError here aborts the start-up.
  test_parser()

  parse_and_bind('tab: complete')
  # Bind the configured key to typing the ask command followed by a newline.
  parse_and_bind(f'"{args.readline_key_send}": "{CMD_ASK}\n"')
  hint = args.readline_key_send.replace('\\', '')
  print(f"Type /help or a question followed by the /ask command (or by pressing "
        f"`{hint}` key).")
  old_model = None   # value of args.model that `gpt4all` was created for
  gpt4all = None     # currently loaded model, None when there is none

  def _apply():
    """ Bring the loaded model and its thread count in line with `args`. """
    nonlocal gpt4all, old_model
    if old_model != args.model:
      if args.model is not None:
        gpt4all = GPT4All(args.model, device=args.device)
      else:
        gpt4all = None
      old_model = args.model
    if gpt4all is not None:
      if args.num_threads != gpt4all.model.thread_count():
        if args.num_threads is not None:
          gpt4all.model.set_thread_count(args.num_threads)
          print("Num threads:", gpt4all.model.thread_count())

  class Repl(Interpreter):
    """ Interprets the parse tree of one input line; the state lives across lines. """
    def __init__(self):
      self.in_echo = False        # True while text is printed instead of collected
      self.message = ""           # text collected for the next ask command
      self.exit_request = False   # set by the exit command
      self.reset_request = False  # asks the outer loop to call _apply() again
    def _finish_echo(self):
      """ Terminate the echoed line, if any, and leave echo mode. """
      if self.in_echo:
        print()
      self.in_echo = False
    def command(self, tree):
      self._finish_echo()
      command = tree.children[0].value
      if command == CMD_ECHO:
        self.in_echo = True
      elif command == CMD_ASK:
        if gpt4all is None:
          raise RuntimeError("No model is active, use /model first")
        ask1(gpt4all, self.message)
        self.message = ""
      elif command == CMD_HELP:
        print_help()
      elif command == CMD_EXIT:
        self.exit_request = True
      elif command == CMD_MODEL:
        matched = False
        # children[2] is taken to be the argument subtree; [1:-1] drops the first
        # and the last character of its token (presumably the quotes -- see the
        # '/model "aaa"' case in test_parser).
        argument = tree.children[2].children[0].value[1:-1]
        for model in model_locations(args, argument):
          if isfile(model):
            args.model = model
            matched = True
            break
        if not matched:
          # No such file: hand the name over as it is; an empty name unloads the model.
          args.model = argument if len(argument)>0 else None
        self.reset_request = True
      elif command == CMD_NTHREADS:
        args.num_threads = int(tree.children[2].children[0].value)
        self.reset_request = True
      elif command == CMD_RESET:
        self.reset_request = True
      else:
        raise ValueError(f"Unknown command: {command}")
    def text(self, tree):
      # NOTE(review): text of successive input lines is appended without any
      # separator -- confirm that this is intended.
      text = tree.children[0].value
      if self.in_echo:
        print(text, end='')
      else:
        self.message += text
    def escape(self, tree):
      # Drop the first character of the escape token and keep the rest.
      text = tree.children[0].value[1:]
      if self.in_echo:
        print(text, end='')
      else:
        self.message += text

  repl = Repl()
  while not repl.exit_request:
    # NOTE(review): an exception raised while creating the model in _apply() is
    # not handled here and ends the program.
    _apply()
    repl.reset_request = False
    with chat_session(gpt4all):
      try:
        while not (repl.exit_request or repl.reset_request):
          repl.visit(PARSER.parse(input(args.readline_prompt)))
          repl._finish_echo()
      except (ValueError,RuntimeError,LarkError) as err:
        # LarkError covers input rejected by the grammar; without it a typo
        # would end the program with a traceback.
        print(err)
      except EOFError:
        print()
        break

# Start the interactive program only when the file is executed as a script.
if __name__ == "__main__":
  main()
