#!/usr/bin/env python3

"""
Usage:
    fsm-step YOUR.MODULE [retry_id]
    fsm-step XXX.py [retry_id]

fsm-step loads your module and calls its function *main*.

The function *main* should return
    `a string that represents next_state`
OR
    `a tuple that represents (next_state, patch_data)`
"""

import importlib
import os
import sys

from requests import Session, ReadTimeout
from tenacity import retry, retry_if_exception_type


class Api:
    """Small HTTP client for the fsmhub service.

    Every request goes through one shared ``requests.Session``.
    """

    def __init__(self, fsmhub):
        """Store the hub's base URL and open the HTTP session.

        ``fsmhub`` may be a bare ``host[:port]`` (``http://`` is then
        prepended) or a full ``http://`` / ``https://`` URL.
        """
        # URL schemes are case-insensitive, so "HTTP://host" must not get a
        # second "http://" glued in front of it.
        scheme = fsmhub.partition("://")[0].lower()
        if scheme not in ("http", "https"):
            fsmhub = "http://" + fsmhub
        # Drop trailing slashes: every request path below starts with "/",
        # and "http://host//lock/x" is not the same URL as "http://host/lock/x".
        self.fsmhub = fsmhub.rstrip("/")
        self.http = Session()

    def get(self, id) -> dict:
        """GET ``<hub>/<id>`` and return the decoded JSON body.

        Raises the ``requests`` HTTP error produced by ``raise_for_status``
        when the hub answers with an error status.
        """
        rsp = self.http.get(f"{self.fsmhub}/{id}")
        rsp.raise_for_status()
        return rsp.json()

    @retry(retry=retry_if_exception_type(ReadTimeout))
    def lock(self, state) -> dict:
        """POST ``<hub>/lock/<state>?wait=yes`` and return the decoded JSON.

        The request uses a 60-second timeout; a ``ReadTimeout`` makes the
        retry decorator issue the request again (no stop condition is
        configured here).  Error statuses raise via ``raise_for_status``.
        """
        # NOTE(review): "wait=yes" presumably makes the hub hold the request
        # open until an item is available -- confirm against the hub's API.
        rsp = self.http.post(f"{self.fsmhub}/lock/{state}?wait=yes", timeout=60)
        rsp.raise_for_status()
        return rsp.json()

    def transit(self, id, next_state, patch_data: dict = None):
        """POST ``<hub>/transit/<id>/<next_state>``.

        ``patch_data`` (a dict or ``None``) is sent as the JSON body.
        Error statuses raise via ``raise_for_status``; nothing is returned.
        """
        rsp = self.http.post(f"{self.fsmhub}/transit/{id}/{next_state}", json=patch_data)
        rsp.raise_for_status()


# Module-wide hub client.  The address is read from $FSMHUB; when that is
# unset or empty the host name "fsmhub" is used instead.
api = Api(os.environ.get("FSMHUB") or "fsmhub")


# Prefer loguru when it is installed.  Otherwise fall back to the stdlib
# logging module itself: its module-level debug() and exception() functions
# cover the logger calls made in this file.
try:
    from loguru import logger
except ImportError:
    import logging as logger


# Make modules in the current working directory importable (an empty
# sys.path entry stands for the current directory).
sys.path.append("")
if len(sys.argv) < 2:
    # No module given: print the usage text instead of an IndexError traceback.
    sys.exit(__doc__ or "Usage: fsm-step YOUR.MODULE [retry_id]")
module_name = sys.argv[1]
if module_name.endswith(".py"):
    # A file path was given instead of a dotted name.  Drop the extension,
    # let normpath collapse a leading "./" and doubled separators (so that
    # "./pkg/step.py" does not become the invalid name "..pkg.step"), then
    # turn the path separators into dots.
    module_name = os.path.normpath(module_name[:-3]).replace(os.sep, ".")

# you MUST define the function *main*
that_function = importlib.import_module(module_name).main


def call(kwargs: dict):
    """Call the loaded *main* function with arguments picked from ``kwargs``.

    Each parameter of *main* is looked up by name in ``kwargs``.  A missing
    name falls back to the parameter's default, or to ``None`` when it has
    no default.  Keys that match no parameter are ignored.  Positional
    parameters are passed positionally and keyword-only parameters by
    keyword; ``*args`` / ``**kwargs`` catch-alls receive nothing.

    The parameter list comes from ``inspect.signature``, so a *main* that is
    wrapped by a ``functools.wraps``-style decorator is still called with
    its real arguments.  ``kwargs`` may be ``None`` or empty.

    Returns whatever *main* returns.
    """
    import inspect  # local import keeps this block self-contained

    data = kwargs or {}
    positional = []
    keyword_only = {}
    for name, param in inspect.signature(that_function).parameters.items():
        if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
            continue  # catch-all parameters are never filled from the payload
        fallback = None if param.default is param.empty else param.default
        value = data.get(name, fallback)
        if param.kind is param.KEYWORD_ONLY:
            keyword_only[name] = value
        else:
            positional.append(value)
    return that_function(*positional, **keyword_only)


def transit(id, result):
    """Report the outcome of a step to the hub.

    ``result`` is either the next state as a string, or a
    ``(next_state, patch_data)`` pair whose second element must be a dict.
    Both values are debug-logged, then the hub is asked to transit item
    ``id``.
    """
    if not isinstance(result, str):
        next_state, patch_data = result
        assert isinstance(patch_data, dict), patch_data
    else:
        # A bare string carries only the next state; there is nothing to patch.
        next_state = result
        patch_data = None

    for value in (next_state, patch_data):
        logger.debug(value)
    api.transit(id, next_state, patch_data)


def main():
    """Process items for the loaded module.

    With a second command-line argument (an integer item id) that single
    item is fetched, processed and transited once; this mode is for fix or
    debug.  Without it, loop forever: lock an item for the module's state,
    run the module's *main* on its data and transit it to the result.
    """
    retry_requested = len(sys.argv) > 2
    if retry_requested:  # this mode is for fix or debug
        item_id = int(sys.argv[2])
        item = api.get(item_id)
        logger.debug(item)
        current = item["state"]
        # The item's state must end with the module name without being
        # exactly the module name.
        assert current.endswith(module_name) and current != module_name, current
        transit(item_id, call(item["data"]))
        return

    # loop mode
    while True:
        item = api.lock(module_name)
        logger.debug(item)
        item_id = item["id"]
        data = item["data"]
        assert item["state"] == module_name

        try:
            outcome = call(data)
        except Exception:
            # Log the failure and move on to the next item; the failed item
            # is not transited.
            logger.exception(f"{module_name}.main")
        else:
            transit(item_id, outcome)


# Script entry point: start processing when this file is run directly.
if __name__ == '__main__':
    main()
