#!python
import datetime
import importlib.metadata
import json
import os
import pathlib
import sys
import time
from typing import Optional

import typer
from rich.console import Console
from rich.markup import escape
from rich.table import Table
from yaspin import yaspin

from quantagonia.cloud.specs_https_client import SpecsHTTPSClient
from quantagonia.cloud.cloud_runner import CloudRunner
from quantagonia.cloud.specs_https_client import JobStatus
from quantagonia.cloud.solver_log import SolverLog
from quantagonia.parser.solution_parser import SolutionParser
from quantagonia.spec_builder import MIPSpecBuilder, QUBOSpecBuilder
from quantagonia.enums import HybridSolverServers

# Typer application object; each function decorated with @app.command()
# below is registered on it as a CLI subcommand.
app = typer.Typer(help="Simple CLI client for Quantagonia's cloud-based HybridSolver.")
# Shared rich console used for markup output (errors, warnings, job tables).
console = Console()

# Both values are assigned in the `if __name__ == "__main__":` block at the
# bottom of this file, before app() runs: API_KEY from the
# QUANTAGONIA_API_KEY environment variable, SERVER from QUANTAGONIA_SERVER.
API_KEY = None  # filled in the __main__ block
SERVER = None  # filled in the __main__ block
# Seconds to wait between polls (passed to time.sleep); despite the name this
# is an interval, not a frequency.
POLL_FREQUENCY = 2

###
# Helpers that minimize code duplication
###

def _exit_with_error(error_msg) -> None:
  """Print *error_msg* as a red "Error:" line and terminate with exit status 1.

  This function never returns, despite the ``-> None`` annotation.

  The message is escaped so that user-supplied text (e.g. a file name that
  contains ``[...]``) is printed literally and not parsed as rich markup.
  ``sys.exit`` is used instead of the ``exit`` builtin. That builtin is only
  installed by the ``site`` module for interactive sessions and is missing
  under ``python -S``.
  """
  message = escape(f"{error_msg}")
  console.print(f"[red]Error:[/red] {message}")
  sys.exit(1)


def _print_warning(msg) -> None:
  """Show *msg* on the shared console behind a yellow "Warning:" label."""
  label = "[yellow]Warning:[/yellow]"
  console.print(f"{label} {msg}")


def _detect_problem_kind(problem_file: str) -> Optional[str]:
  """Return "mip", "qubo" or None based on the suffix of the file name.

  Matching uses the end of the name, so files with extra dots such as
  ``model.v2.mps`` are recognised; gzip variants (``.gz``) are accepted too.
  """
  name = pathlib.Path(problem_file).name
  if name.endswith((".mps", ".lp", ".mps.gz", ".lp.gz")):
    return "mip"
  if name.endswith((".qubo", ".qubo.gz")):
    return "qubo"
  return None


def _build_spec(
  problem_kind: str,
  spec_file: Optional[str],
  relative_gap: float,
  absolute_gap: float,
  timelimit: int) -> dict:
  """Build the solver spec dict for a problem of kind "mip" or "qubo".

  Without *spec_file*, the gaps and time limit are applied to the default
  spec. With *spec_file*, its JSON object's keys override entries of the
  default ``solver_config``; unknown keys produce a warning and are skipped,
  and the gap/time-limit arguments are not applied.
  Exits via _exit_with_error if the spec file cannot be read or parsed.
  """
  builder = MIPSpecBuilder() if problem_kind == "mip" else QUBOSpecBuilder()

  if spec_file is None:
    builder.set_time_limit(timelimit)
    builder.set_relative_gap(relative_gap)
    builder.set_absolute_gap(absolute_gap)
    return builder.getd()

  try:
    with open(spec_file, "r") as f:
      given_specs = json.load(f)
  except FileNotFoundError:
    _exit_with_error(f"Spec file {spec_file} does not exist, exiting...")
  except OSError as err:
    _exit_with_error(f"Could not read spec file {spec_file} ({err}), exiting...")
  except ValueError as err:
    # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
    _exit_with_error(f"Spec file {spec_file} is not valid JSON ({err}), exiting...")

  if not isinstance(given_specs, dict):
    _exit_with_error(f"Spec file {spec_file} must contain a JSON object, exiting...")

  spec = builder.getd()
  for key, value in given_specs.items():
    if key not in spec["solver_config"]:
      _print_warning(f"Given solver option key {key} does not exist. Ignore.")
      continue
    spec["solver_config"][key] = value
  return spec


def _submit(
  client : SpecsHTTPSClient,
  problem_file : str,
  spec_file : Optional[str],
  tag : str,
  relative_gap : float,
  absolute_gap : float,
  timelimit : int,
  quiet: bool) -> str:
  """Validate *problem_file*, build its solver spec and submit it as a job.

  Args:
    client: HTTPS client used for the submission.
    problem_file: path to a .mps/.lp (MIP) or .qubo (QUBO) file, optionally
      gzip-compressed (.gz).
    spec_file: optional JSON file overriding ``solver_config`` entries; when
      given, relative_gap, absolute_gap and timelimit are not applied.
    tag: tag passed through to ``client.submitJob``.
    relative_gap: relative-gap stopping criterion (used without spec_file).
    absolute_gap: absolute-gap stopping criterion (used without spec_file).
    timelimit: time limit (used without spec_file).
    quiet: print plain text instead of showing a spinner.

  Returns:
    The job ID returned by ``client.submitJob``.

  Exits the process with status 1 on any validation or submission error.
  """
  if not os.path.isfile(problem_file):
    _exit_with_error(f"File {problem_file} does not exist, exiting...")

  problem_kind = _detect_problem_kind(problem_file)
  if problem_kind is None:
    _exit_with_error(f"File {problem_file} is not in MIP or QUBO file format, exiting...")

  spec = _build_spec(problem_kind, spec_file, relative_gap, absolute_gap, timelimit)

  if quiet:
    try:
      job_id = client.submitJob(problem_files=[problem_file], specs=[spec], tag=tag)
    except Exception as err:
      _exit_with_error(f"Failed to submit job: {err}")
    print(f"Submitted job with ID: {job_id}")
    return job_id

  failed = False
  with yaspin(text="Submitting job to the Quantagonia cloud...", color="yellow") as spinner:
    try:
      job_id = client.submitJob(problem_files=[problem_file], specs=[spec], tag=tag)
    except Exception as err:
      spinner.text = f"Failed to submit job: {err}"
      spinner.fail("❌")
      failed = True
    else:
      spinner.text = f"Submitted job with ID: {job_id}"
      spinner.ok("✅")

  # Exit only after the spinner context has closed so the terminal is restored.
  if failed:
    sys.exit(1)

  return job_id


def _status(job_id : str, item : int, client : SpecsHTTPSClient, spinner : yaspin) -> bool:
  """Show the status of batch entry *item* of job *job_id* on *spinner*.

  On success the spinner is finalised (via ``ok``) with an icon for the state:
  ⏳ for CREATED, 💻 for RUNNING, ✅ for FINISHED/SUCCESS and ❌ for anything
  else. Returns True if the status was retrieved. Returns False if retrieval
  raised; the spinner is then marked as failed.
  """
  try:
    all_statuses = client.getCurrentStatus(job_id)
    current = JobStatus(all_statuses[item])

    spinner.text = f"Status: {current.value}"

    if current == JobStatus.created:
      icon = "⏳"
    elif current == JobStatus.running:
      icon = "💻"
    elif current in (JobStatus.finished, JobStatus.success):
      icon = "✅"
    else:
      # Retrieval itself worked, so this is still `ok`, just with a red cross.
      icon = "❌"
    spinner.ok(icon)
  except Exception:
    spinner.text = f"Failed to retrieve status of job {job_id}"
    spinner.fail("❌")
    return False

  return True


def _follow_job(job_id : str, client : SpecsHTTPSClient, solver_log = None, item : int = 0):
  """Follow a job until it leaves CREATED/RUNNING, streaming its log, then show the final status.

  Polls ``client`` every POLL_FREQUENCY seconds. On each poll it passes the
  current log of batch entry *item* to ``solver_log.updateLog``. The spinner
  then ends with ✅ for FINISHED/SUCCESS and ❌ for any other final status.

  Args:
    job_id: ID of the job to follow.
    client: HTTPS client used for the status and log requests.
    solver_log: SolverLog instance that prints log updates; a new instance is
      created when omitted.
    item: index of the batch entry to follow (default 0, the first entry).

  Exits the process with status 1 if status or log retrieval fails.
  """
  if solver_log is None:
    # The previous default was the SolverLog class itself rather than an
    # instance, so calling updateLog on it could not work.
    solver_log = SolverLog()

  failed = False
  with yaspin(text=f"Processing job {job_id}...", color="yellow") as spinner:
    try:
      while True:
        status = JobStatus(client.getCurrentStatus(job_id)[item])

        # The log is fetched after the status, so the last update shown is
        # taken after a terminal status was observed.
        logs = client.getCurrentLog(job_id)
        with spinner.hidden():
          solver_log.updateLog(logs[item])

        if status not in (JobStatus.created, JobStatus.running):
          break
        # Only wait when another poll follows; no trailing sleep once done.
        time.sleep(POLL_FREQUENCY)

      spinner.text = f"Status: {status.value}"
      if status in (JobStatus.finished, JobStatus.success):
        spinner.ok("✅")
      else:
        spinner.fail("❌")
    except Exception as err:
      spinner.text = f"Failed to retrieve status for job {job_id}: {err}"
      spinner.fail("❌")
      failed = True

  # Exit outside the `with` statement so the spinner is closed first.
  if failed:
    sys.exit(1)

def _retrieve_billing_time(job_id : str, client : SpecsHTTPSClient):
  """Show the minutes billed for *job_id* on a spinner line.

  Only the first entry of ``client.getResults(job_id)`` is consulted. A
  failure is reported on the spinner but is not fatal; the function returns
  normally either way.
  """
  progress_text = f"Retrieving billing time for job {job_id}..."
  with yaspin(text=progress_text, color="yellow") as spinner:
    try:
      first_result = client.getResults(job_id)[0]
      spinner.text = f"Minutes billed: {first_result['time_billed']}"
      spinner.ok("✅")
    except Exception:
      spinner.text = "Failed to retrieve billing time."
      spinner.fail("❌")


###
# CLI commands
###
@app.command()
def solve(
  problem_file : str = typer.Argument(
    help="Path to optimization problem file."
  ),
  spec_file : str = typer.Option(
    None,
    help="Path to spec file. If specified, override other options."
  ),
  relative_gap : float = typer.Option(
    1e-4,
    help="Stopping criterion: relative gap"),
  absolute_gap : float = typer.Option(
    1e-6,
    help="Stopping criterion: absolute gap"),
  timelimit : int = typer.Option(
    3600,
    help="Time limit for computation"),
  tag : str = typer.Option(
    "",
    help="Tag to identify the job later."
  ),
  quiet : bool = typer.Option(
    False,
    help="Disable interactive output and only show final logs")):
  """
  Submit a model in .mps, .lp or .qubo format (either plain or as .gz) and actively follow the progress. This command is equivalent to a 'submit' command followed by a 'follow'.
  """

  runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
  client = runner.httpsClient()

  # Pass the numeric options by keyword. _submit takes relative_gap before
  # absolute_gap, and the former positional call had the two swapped.
  job_id = _submit(
    client, problem_file, spec_file, tag,
    relative_gap=relative_gap,
    absolute_gap=absolute_gap,
    timelimit=timelimit,
    quiet=quiet)

  solver_log = SolverLog()
  solver_log.nextTimeAddNewLine()

  if not quiet:
    # Interactive mode: stream logs behind a spinner, then report billing.
    _follow_job(job_id, client, solver_log)
    _retrieve_billing_time(job_id, client)
    return

  # Quiet mode: poll silently until the job leaves CREATED/RUNNING, then
  # print the log once.
  try:
    while True:
      status = JobStatus(client.getCurrentStatus(job_id)[0])
      if status not in (JobStatus.created, JobStatus.running):
        break
      time.sleep(POLL_FREQUENCY)

    logs = client.getCurrentLog(job_id)
    solver_log.updateLog(logs[0])
  except Exception as err:
    _exit_with_error(f"Failed to retrieve status: {err}")

  print(f"\nFinished job with status {status.value}")


@app.command()
def submit(
  problem_file : str = typer.Argument(
    help="Path to optimization problem file."
  ),
  spec_file : str = typer.Option(
    None,
    help="Path to spec file. If specified, override other options."
  ),
  relative_gap : float = typer.Option(
    1e-4,
    help="Stopping criterion: relative gap"),
  absolute_gap : float = typer.Option(
    1e-6,
    help="Stopping criterion: absolute gap"),
  timelimit : int = typer.Option(
    3600,
    help="Time limit for computation"),
  tag : str = typer.Option(
    "",
    help="Tag to identify the job later."
  ),
  quiet : bool = typer.Option(
    False,
    help="Disable interactive output and only show final logs")):
  """
  Submit the given problem in a non-blocking way. Use status, logs, and solution commands to get results.
  """

  runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
  client = runner.httpsClient()

  # Keyword arguments keep the gap/limit options from being swapped by
  # accident; _submit returns the new job ID.
  return _submit(
    client, problem_file, spec_file, tag,
    relative_gap=relative_gap,
    absolute_gap=absolute_gap,
    timelimit=timelimit,
    quiet=quiet)

@app.command()
def logs(
  job_id : str,
  item : int = 0):
  """
  Print the current logs of a given job.
  For batched jobs, the optional parameter 'item' selects the item of the batch.
  """
  runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
  client = runner.httpsClient()

  # Single-shot dump of all currently available log lines.
  failed = False
  with yaspin(text=f"Retrieving current logs for job {job_id}...", color="yellow") as spinner:
    try:
      # Select the batch item and split it before reporting success, so an
      # invalid `item` shows up as a failure instead of after a ✅.
      log_lines = client.getCurrentLog(job_id)[item].split("\n")
    except Exception as err:
      spinner.text = f"Failed to retrieve logs: {err}"
      spinner.fail("❌")
      failed = True
    else:
      spinner.text = "Retrieved logs:"
      spinner.ok("✅")
      with spinner.hidden():
        for line in log_lines:
          print(line)

  # Exit outside the spinner context so the terminal is restored first.
  if failed:
    sys.exit(1)


@app.command()
def follow(
  job_id : str,
  item : int = 0):
  """
  Resumes following the progress (i.e., logs) of a given job, e.g., after a 'submit' command.
  For batched jobs, the optional parameter 'item' selects the item of the batch.
  """
  # NOTE(review): `item` is accepted but not passed on. As called here,
  # _follow_job and _retrieve_billing_time read batch entry 0 only.
  https_client = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True).httpsClient()

  # Stream logs until the job reaches a final state, then show the billing.
  _follow_job(job_id, https_client, SolverLog())
  _retrieve_billing_time(job_id, https_client)


@app.command()
def status(
  job_id : str,
  item : int = 0
):
  """
  Retrieves the status of a given job (one of CREATED, RUNNING, FINISHED, SUCCESS,
  TIMEOUT, TERMINATED, ERROR).
  For batched jobs, the optional parameter 'item' selects the item of the batch.
  """

  with yaspin(text=f"Retrieving status for job {job_id}...", color="yellow") as spinner:
    runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
    client = runner.httpsClient()

    # _status returns False when the status could not be retrieved.
    retrieved = _status(job_id, item, client, spinner)

  # Like `logs`, report a retrieval failure through the exit code; exit only
  # after the spinner context has closed.
  if not retrieved:
    sys.exit(1)


@app.command()
def solution(
  job_id : str,
  item : int = 0):
  """
  Display the solution (vector) for a given job if its computation completed
  with success.
  For batched jobs, the optional parameter 'item' selects the item of the batch.
  """

  failed = False
  with yaspin(text=f"Retrieving solution for job {job_id}...", color="yellow") as spinner:
    runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
    client = runner.httpsClient()

    try:
      # Fetch and parse before marking success, so a parse error is not
      # reported after a ✅.
      result = client.getResults(job_id)[item]
      parsed = SolutionParser.parse(result["solution_file"])
    except Exception as err:
      spinner.text = f"Failed to retrieve solution: {err}"
      spinner.fail("❌")
      failed = True
    else:
      spinner.text = "Retrieved solution:"
      spinner.ok("✅")
      print(parsed)

  # Signal failure through the exit code, outside the spinner context.
  if failed:
    sys.exit(1)

@app.command()
def time_billed(
  job_id : str):
  """
  Output the time billed for a particular job in minutes.
  """
  # Lookup and spinner output are delegated to the shared helper.
  _retrieve_billing_time(
    job_id,
    CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True).httpsClient())

@app.command()
def list(
  n : int = typer.Option(
    10,
    help="Maximum number of jobs to display")):
  """
  Shows a list of the user's n latest jobs with some basic information.
  """
  # NOTE: this function's name shadows the builtin `list` at module level.
  # It is kept because Typer derives the command name from it, so the
  # builtin must not be called in this file after this definition.

  def _cell(value) -> str:
    """Convert an API field to table text.

    rich's Table.add_row rejects non-renderable values such as int or float,
    and plain strings are parsed as console markup, so stringify and escape.
    None becomes an empty cell, as add_row itself would render it.
    """
    return "" if value is None else escape(f"{value}")

  def jobs2table(jobs, title=""):
    """Build a rich Table with one row per job dict of a getJobs() response."""
    tbl = Table(title=f"[bold]{title}[/bold]", title_justify="left")
    tbl.add_column()  # finished / successful marker
    tbl.add_column("Job ID")
    tbl.add_column("Size", justify="right")
    tbl.add_column("Tag")
    tbl.add_column("Type(s)")
    tbl.add_column("Filename(s)")
    tbl.add_column("Created")
    tbl.add_column("Time billed", justify="right")

    for job in jobs:
      bs = int(job["batch_size"])
      dt = datetime.datetime.fromtimestamp(int(job["created"]))
      # For batches, note how many further items follow the first one shown.
      extra = f" (+ {bs - 1})" if bs > 1 else ""

      marker = ""
      if bool(job["finished"]):
        marker = "[green]✔[/green]" if bool(job["successful"]) else "[red]✗[/red]"

      tbl.add_row(
        marker,
        _cell(job["job_id"]),
        f"{bs}",
        "---" if job["tag"] == "" else _cell(job["tag"]),
        escape(job["first_type"]) + extra,
        escape(job["first_filename"]) + extra,
        dt.strftime("%d.%m.%Y %H:%M:%S"),
        _cell(job["time_billed"]))
    return tbl

  failed = False
  with yaspin(text=f"Retrieving latest {n} jobs for given API key...", color="yellow") as spinner:
    runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
    client = runner.httpsClient()

    try:
      res = client.getJobs(n)
      running_jobs = res["running"]
      old_jobs = res["old"]
      summary = f"You have {len(running_jobs)} active and {len(old_jobs)} finished jobs" + \
        (":" if len(running_jobs) + len(old_jobs) > 0 else ".")
      # Build the tables before reporting success, so malformed job data is
      # reported as a failure rather than after a ✅.
      tables = [
        jobs2table(jobs, title=title)
        for jobs, title in ((running_jobs, "Active jobs"), (old_jobs, "Finished jobs"))
        if len(jobs) > 0
      ]
    except Exception as err:
      spinner.text = f"Failed to retrieve list of jobs: {err}"
      spinner.fail("❌")
      failed = True
    else:
      spinner.text = summary
      spinner.ok("✅")
      for tbl in tables:
        print("")
        console.print(tbl)

  # Signal failure through the exit code, outside the spinner context.
  if failed:
    sys.exit(1)

@app.command()
def cancel(
  job_id : str):
  """
  Cancel a job that is currently running.
  """

  failed = False
  with yaspin(text=f"Canceling job {job_id}...", color="yellow") as spinner:
    runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
    client = runner.httpsClient()

    try:
      client.interruptJob(job_id)
    except Exception as err:
      # Surface the cause instead of silently discarding it.
      spinner.text = f"Failed to cancel job: {err}"
      spinner.fail("❌")
      failed = True
    else:
      spinner.text = "Job cancelled"
      spinner.ok("✅")

  # Non-zero exit on failure, consistent with `logs` and `submit`.
  if failed:
    sys.exit(1)

@app.command()
def api_key():
  """
  Prints the API key set through QUANTAGONIA_API_KEY.
  """
  # API_KEY is the module-level value assigned in the __main__ block.
  configured_key = API_KEY
  print(configured_key)

@app.command()
def version():
  """
  Prints the version of this Python package.
  """
  # Read the version from the installed distribution's metadata.
  print(importlib.metadata.version("quantagonia-api-client"))

if __name__ == "__main__":

  # Every command talks to the cloud, so the API key is required up front.
  if "QUANTAGONIA_API_KEY" not in os.environ:
    _exit_with_error("Quantagonia API Key not found. Please set the 'QUANTAGONIA_API_KEY' environment variable.")
  API_KEY = os.environ["QUANTAGONIA_API_KEY"]

  # Internal users can select a server other than PROD through an env variable.
  server_name = os.environ.get("QUANTAGONIA_SERVER", "PROD").upper()
  try:
    SERVER = HybridSolverServers[server_name]
  except KeyError:
    # Report an unknown value cleanly instead of with a KeyError traceback.
    valid_names = ", ".join(member.name for member in HybridSolverServers)
    _exit_with_error(f"Unknown QUANTAGONIA_SERVER '{server_name}'. Valid values: {valid_names}.")
  if SERVER != HybridSolverServers.PROD:
    _print_warning(f"Job is submitted to {SERVER.name} environment.")

  app()
