#!python

import datetime
import importlib.metadata
import json
import os
import pathlib
import sys
import time
from typing import Optional

import typer
from rich import print as richprint
from rich.console import Console
from rich.table import Table
from yaspin import yaspin

from quantagonia.cloud.cloud_runner import CloudRunner
from quantagonia.cloud.solver_log import SolverLog
from quantagonia.cloud.specs_https_client import JobStatus
from quantagonia.cloud.specs_https_client import SpecsHTTPSClient
from quantagonia.enums import HybridSolverServers
from quantagonia.parser.solution_parser import SolutionParser
from quantagonia.spec_builder import MIPSpecBuilder, QUBOSpecBuilder

# Root Typer application; every function decorated with @app.command() below
# is registered as a sub-command of this CLI.
app = typer.Typer(
  help="Simple CLI client for Quantagonia's cloud-based HybridSolver.",
)
console = Console()  # shared rich console used for the job tables

# Seconds slept between two consecutive status/log polls (despite its name,
# this is an interval, not a frequency).
POLL_FREQUENCY = 2
# HybridSolver deployment that the commands talk to.
SERVER = HybridSolverServers.PROD
# Filled from the QUANTAGONIA_API_KEY environment variable by the __main__
# guard at the bottom of this file; None until then.
API_KEY = None

###
# Helpers that minimize code duplication
###

def _submit(
  client : SpecsHTTPSClient,
  problem_file : str,
  spec_file : Optional[str],
  tag : str,
  relative_gap : float,
  absolute_gap : float,
  timelimit : int) -> str:
  """
  Validate a problem file, build its solver spec and submit it as a new job.

  Args:
    client: HTTPS client the job is submitted through.
    problem_file: Path to a MIP (.mps / .lp) or QUBO (.qubo) file, either plain
      or compressed as .gz / .zip.
    spec_file: Optional path to a JSON file. Its entries overwrite matching keys
      of the builder's default "solver_config"; unknown keys are reported and
      skipped. When given, relative_gap, absolute_gap and timelimit are ignored.
    tag: Tag attached to the job.
    relative_gap: Relative-gap stopping criterion (used without spec_file).
    absolute_gap: Absolute-gap stopping criterion (used without spec_file).
    timelimit: Time limit (used without spec_file).

  Returns:
    The job ID returned by client.submitJob.

  Raises:
    Exception: If problem_file does not exist or has an unsupported extension.
    SystemExit: With code 1 if the submission request fails.
  """
  # check if file exists
  if not os.path.isfile(problem_file):
    raise Exception(f"File {problem_file} does not exist, exiting...")

  # Classify by suffix rather than by "everything after the first dot", so that
  # names containing extra dots (e.g. "model.v2.mps.gz") are recognized too.
  file_name = pathlib.Path(problem_file).name
  is_mip = file_name.endswith((".mps", ".lp", ".mps.gz", ".lp.gz", ".mps.zip", ".lp.zip"))
  is_qubo = file_name.endswith((".qubo", ".qubo.zip", ".qubo.gz"))

  if not is_mip and not is_qubo:
    raise Exception(f"File {problem_file} is not supported MIP or QUBO file format, exiting...")

  # Start from the default spec of the detected problem type.
  builder = MIPSpecBuilder() if is_mip else QUBOSpecBuilder()

  if spec_file is None:
    builder.set_time_limit(timelimit)
    builder.set_relative_gap(relative_gap)
    builder.set_absolute_gap(absolute_gap)
    spec = builder.getd()
  else:
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(spec_file, "r", encoding="utf-8") as f:
      given_specs_dict = json.load(f)
    spec = builder.getd()
    for k, v in given_specs_dict.items():
      if k not in spec["solver_config"]:
        print(f"Given solver option key {k} does not exist. Ignore.")
        continue
      spec["solver_config"][k] = v

  # start solving job
  has_error = False
  job_id = None
  with yaspin(text=" Submitting job to the Quantagonia cloud...", color="yellow") as spinner:
    try:
      job_id = client.submitJob(problem_files = [problem_file], specs = [spec],
        tag = tag)

      spinner.text = f"Submitted job with ID: {job_id}"
      spinner.ok("✅")
    except Exception:
      spinner.text = "Failed to submit job"
      spinner.fail("❌")

      has_error = True

  if has_error:
    # sys.exit instead of the site-provided exit(), which is not guaranteed to
    # exist (e.g. under `python -S` or in frozen executables).
    sys.exit(1)

  return job_id

###
# CLI commands
###
@app.command()
def solve(
  problem_file : str = typer.Argument(
    help="Path to optimization problem file."
  ),
  spec_file : Optional[str] = typer.Option(
    None,
    help="Path to spec file. If specified, override other options."
  ),
  relative_gap : float = typer.Option(
    1e-4,
    help="Stopping criterion: relative gap"),
  absolute_gap : float = typer.Option(
    1e-6,
    help="Stopping criterion: absolute gap"),
  timelimit : int = typer.Option(
    3600,
    help="Time limit for computation"),
  tag : str = typer.Option(
    "",
    help="Tag to identify the job later."
  ),
  follow : bool = typer.Option(
    True,
    help="Whether to stream logs interactively during runtime")):
  """
  Simple solve command - upload a file (mps/lp/qubo, each plain, gz or zip),
  set some options and solve the problem. Outputs streaming logs, solution
  and time billed.
  """
  # NOTE: the docstring above doubles as this command's --help text.
  # Exits with status 1 if submission, status polling or result retrieval fails.

  runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
  client = runner.httpsClient()

  # Submit the job. The gaps are passed by keyword because _submit takes
  # relative_gap before absolute_gap and positional passing mixes them up easily.
  job_id = _submit(client, problem_file, spec_file, tag,
    relative_gap=relative_gap, absolute_gap=absolute_gap, timelimit=timelimit)

  # Poll until the job leaves the created/running states; with --follow, the
  # solver log is streamed on every poll.
  active_states = (JobStatus.created, JobStatus.running)
  solver_log = SolverLog()
  solver_log.nextTimeAddNewLine()
  has_error = False
  with yaspin(text=f" Waiting for job {job_id} to finish...", color="yellow") as spinner:
    try:
      while True:
        # Status is queried before the log, so the last log fetch happens after
        # the job has already left the active states.
        job_status = JobStatus(client.getCurrentStatus(job_id)[0])

        if follow:
          log_texts = client.getCurrentLog(job_id)
          with spinner.hidden():
            solver_log.updateLog(log_texts[0])

        if job_status not in active_states:
          break
        # Only wait when another poll will follow.
        time.sleep(POLL_FREQUENCY)

      spinner.text = f"Status: {job_status.value}"
      if job_status in [JobStatus.finished, JobStatus.success]:
        spinner.ok("✅")
      else:
        spinner.fail("❌")
    except Exception:
      spinner.text = "Failed to retrieve status"
      spinner.fail("❌")

      has_error = True

  if has_error:
    sys.exit(1)

  # retrieve solution vector
  with yaspin(text=f" Retrieving solution for job {job_id}...", color="yellow") as spinner:
    try:
      res = client.getResults(job_id)[0]

      spinner.text = "Retrieved solution:"
      spinner.ok("✅")
    except Exception:
      spinner.text = "Failed to retrieve solution."
      spinner.fail("❌")

      has_error = True

  if has_error:
    sys.exit(1)

  print(SolutionParser.parse(res["solution_file"]))

  # display billed time
  print(f"✅ Minutes billed: {res['time_billed']}")

@app.command()
def submit(
  problem_file : str = typer.Argument(
    help="Path to optimization problem file."
  ),
  spec_file : Optional[str] = typer.Option(
    None,
    help="Path to spec file. If specified, override other options."
  ),
  relative_gap : float = typer.Option(
    1e-4,
    help="Stopping criterion: relative gap"),
  absolute_gap : float = typer.Option(
    1e-6,
    help="Stopping criterion: absolute gap"),
  timelimit : int = typer.Option(
    3600,
    help="Time limit for computation"),
  tag : str = typer.Option(
    "",
    help="Tag to identify the job later."
  ),):
  """
  Submit the given problem in a non-blocking way. Use status, logs, and solution commands to get results.
  """
  # NOTE: the docstring above doubles as this command's --help text.
  # Returns the job ID produced by _submit (which also reports it on screen).

  runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
  client = runner.httpsClient()

  # Keyword arguments keep the two gap values from being swapped.
  return _submit(client, problem_file, spec_file, tag,
    relative_gap=relative_gap, absolute_gap=absolute_gap, timelimit=timelimit)

@app.command()
def logs(
  job_id : str,
  item : int = 0,
  follow : bool = typer.Option(
    True,
    help="Whether to stream logs interactively during runtime")):
  """
  Print the current logs of a given job or follow the stream of its log entries.
  For batched jobs, the optional parameter 'item' selects the item of the batch.
  """
  # NOTE: the docstring above doubles as this command's --help text.
  runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
  client = runner.httpsClient()

  if not follow:
    # single-shot log dumping
    with yaspin(text=f" Retrieving logs for job {job_id}...", color="yellow") as spinner:
      try:
        # The batch item is selected inside the try so that an invalid index
        # yields only the failure line, not a success line followed by it.
        log_text = client.getCurrentLog(job_id)[item]
      except Exception:
        spinner.text = "Failed to retrieve logs"
        spinner.fail("❌")
        return

      spinner.text = "Retrieved logs:"
      spinner.ok("✅")

    print(log_text)
    return

  # Like the solve command: keep polling and stream new log entries until the
  # job has left the created/running states.
  active_states = (JobStatus.running, JobStatus.created)
  solver_log = SolverLog()
  while True:
    # Query the status BEFORE the log: once the job is no longer active, the log
    # fetched afterwards contains its final entries. Fetching the log first
    # could drop lines written between the two requests.
    job_status = JobStatus(client.getCurrentStatus(job_id)[item])
    solver_log.updateLog(client.getCurrentLog(job_id)[item])

    if job_status not in active_states:
      break
    time.sleep(POLL_FREQUENCY)

@app.command()
def status(
  job_id : str,
  item : int = 0
):
  """
  Retrieves the status of a given job (one of CREATED, RUNNING, FINISHED, SUCCESS,
  TIMEOUT, TERMINATED, ERROR).
  For batched jobs, the optional parameter 'item' selects the item of the batch.
  """
  # NOTE: the docstring above doubles as this command's --help text.

  with yaspin(text=f"Retrieving status for job {job_id}...", color="yellow") as spinner:
    runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
    client = runner.httpsClient()

    try:
      # Named job_status to avoid shadowing this command function.
      job_status = JobStatus(client.getCurrentStatus(job_id)[item])
    except Exception:
      spinner.text = f" Failed to retrieve status"
      spinner.fail("❌")
      return

    spinner.text = f"Status: {job_status.value}"

    if job_status == JobStatus.created:
      spinner.ok("⏳")
    elif job_status == JobStatus.running:
      spinner.ok("💻")
    elif job_status in [JobStatus.finished, JobStatus.success]:
      spinner.ok("✅")
    else:
      # Any other (unsuccessful) terminal state is a failure, as in `solve`.
      spinner.fail("❌")

@app.command()
def solution(
  job_id : str,
  item : int = 0):
  """
  Display the solution (vector) for a given job if its computation completed
  with success.
  For batched jobs, the optional parameter 'item' selects the item of the batch.
  """
  # NOTE: the docstring above doubles as this command's --help text.

  with yaspin(text=f"Retrieving solution for job {job_id}...", color="yellow") as spinner:
    runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
    client = runner.httpsClient()

    try:
      res = client.getResults(job_id)[item]
      # Parse before marking success, so a parse error is reported as the single
      # failure line instead of following an already printed success line.
      parsed_solution = SolutionParser.parse(res["solution_file"])
    except Exception:
      spinner.text = "Failed to retrieve solution."
      spinner.fail("❌")
      return

    spinner.text = "Retrieved solution:"
    spinner.ok("✅")

  print(parsed_solution)

@app.command()
def time_billed(
  job_id : str):
  """
  Output the time billed for a particular job in minutes.
  """
  # NOTE: the docstring above doubles as this command's --help text.
  # The billed minutes are read from the first entry of the job's results.
  with yaspin(text=f"Retrieving billing time for job {job_id}...", color="yellow") as spinner:
    cloud = CloudRunner(api_key=API_KEY, server=SERVER, suppress_log=True)
    http = cloud.httpsClient()

    try:
      first_result = http.getResults(job_id)[0]
      minutes = first_result['time_billed']
      spinner.text = f"Minutes billed: {minutes}"
      spinner.ok("✅")
    except Exception:
      spinner.text = "Failed to retrieve billing time."
      spinner.fail("❌")

def _jobs_to_table(jobs, title=""):
  """
  Build a rich Table with one row per job record.

  Each record is read through the keys "batch_size", "created" (passed to
  datetime.fromtimestamp, i.e. a UNIX timestamp), "finished", "successful",
  "job_id", "tag", "first_type", "first_filename" and "time_billed" — presumably
  the schema of client.getJobs() entries; TODO confirm against the API.
  """
  tbl = Table(title=f"[bold]{title}[/bold]", title_justify="left")
  tbl.add_column()
  tbl.add_column("Job ID")
  tbl.add_column("Size", justify="right")
  tbl.add_column("Tag")
  tbl.add_column("Type(s)")
  tbl.add_column("Filename(s)")
  tbl.add_column("Created")
  tbl.add_column("Time billed", justify="right")

  for job in jobs:
    bs = int(job["batch_size"])
    # For batches, only the first item's type/filename is shown plus a count.
    more = f" (+ {bs - 1})" if bs > 1 else ""
    created = datetime.datetime.fromtimestamp(int(job["created"]))

    finished = bool(job["finished"])
    if finished and bool(job["successful"]):
      mark = "[green]✔[/green]"
    elif finished:
      mark = "[red]✗[/red]"
    else:
      mark = ""  # job still active

    # rich's Table.add_row only accepts strings/renderables, so API values are
    # converted explicitly (a numeric time_billed would otherwise raise).
    tbl.add_row(
      mark,
      str(job["job_id"]),
      f"{bs}",
      str(job["tag"]) if job["tag"] else "---",
      f"{job['first_type']}{more}",
      f"{job['first_filename']}{more}",
      created.strftime("%d.%m.%Y %H:%M:%S"),
      str(job["time_billed"]))
  return tbl

@app.command()
def list(
  n : int = typer.Option(
    10,
    help="Maximum number of jobs to display")):
  """
  Shows a list of the user's n latest jobs with some basic information.
  """
  # NOTE: the docstring above doubles as this command's --help text.
  # (This function shadows the builtin `list` at module level; the name is kept
  # because it is the command's public name.)

  with yaspin(text=f" Retrieving latest {n} jobs for given API key...", color="yellow") as spinner:
    runner = CloudRunner(api_key = API_KEY, server = SERVER, suppress_log = True)
    client = runner.httpsClient()

    try:
      res = client.getJobs(n)
      num_running = len(res["running"])
      num_old = len(res["old"])
    except Exception:
      spinner.text = "Failed to retrieve list of jobs."
      spinner.fail("❌")
      return

    spinner.text = f"You have {num_running} active and {num_old} finished jobs" + \
      (":" if num_running + num_old > 0 else ".")
    spinner.ok("✅")

  # output jobs in a nice, tabular format
  if num_running > 0:
    print("")
    console.print(_jobs_to_table(res["running"], title="Active jobs"))

  if num_old > 0:
    print("")
    console.print(_jobs_to_table(res["old"], title="Finished jobs"))

@app.command()
def cancel(
  job_id : str):
  """
  Cancel a job that is currently running.
  """
  # NOTE: the docstring above doubles as this command's --help text.
  # Sends an interrupt request for the job and reports the outcome.
  with yaspin(text=f"Canceling job {job_id}...", color="yellow") as spinner:
    cloud = CloudRunner(api_key=API_KEY, server=SERVER, suppress_log=True)
    http = cloud.httpsClient()

    try:
      http.interruptJob(job_id)
      # Only reached when the interrupt request did not raise.
      spinner.text = "Job cancelled"
      spinner.ok("✅")
    except Exception:
      spinner.text = "Failed to cancel job"
      spinner.fail("❌")

@app.command()
def api_key():
  """
  Prints the API key set through QUANTAGONIA_API_KEY.
  """
  # Echo the module-level key, which the __main__ guard fills from the
  # environment (prints "None" if it was never set).
  configured_key = API_KEY
  print(configured_key)

@app.command()
def version():
  """
  Prints the API client's version.
  """
  # NOTE: the docstring above doubles as this command's --help text.
  # The version comes from the installed package metadata; an uninstalled
  # source checkout has none, which is reported instead of a raw traceback.
  try:
    client_version = importlib.metadata.version("quantagonia-api-client")
  except importlib.metadata.PackageNotFoundError:
    # sys.exit(msg) writes msg to stderr and exits with status 1.
    sys.exit("Package 'quantagonia-api-client' is not installed; version unknown.")
  print(client_version)

if __name__ == "__main__":

  # The API key is read only when this file runs as a script. It is stored in
  # the module-level API_KEY that the commands pass to CloudRunner.
  if "QUANTAGONIA_API_KEY" not in os.environ:
    # sys.exit(msg) prints msg to stderr and exits with status 1, without
    # showing the user a Python traceback.
    sys.exit("Please set QUANTAGONIA_API_KEY environment variable")

  print("")
  API_KEY = os.environ["QUANTAGONIA_API_KEY"]
  app()
