#!/usr/bin/env python3
""""
Simple utility for generating version strings based on the information in the git index.

Supports three modes: parts, local, and dev.

parts
    Display all available version parts.
    Format: {timestamp: str, short_hash: str, closest_tag: str, distance: int}

    $ version parts | jq
    {
        "timestamp": "1709759768",
        "closest_tag": "0.2.0",
        "distance": 4,
        "short_hash": "4e2b423"
    }

local
    Generate a PEP 440 compliant local version tag.
    Format: {last_tag}+{num commits}.{commit hash}

    $ version local
    0.0.0+50.440f9b6

dev
    Generate a PEP 440 compliant development prerelease version tag.
    Format: {last_tag + 1}.dev{timestamp}

    $ version dev
    0.0.1.dev1663305183

Implementation based on versioneer but is unlikely to be robust to old git versions and
tagging schemes that deviate from PEP 440.
"""
import re
import subprocess
import sys


def panic(message):
    """Print *message* to stderr and terminate with exit status 1.

    Args:
        message: Text to show the user; written to ``sys.stderr``.

    Raises:
        SystemExit: Always, with exit code 1.
    """
    print(message, file=sys.stderr)
    # Use sys.exit rather than the bare ``exit`` builtin: the latter is injected
    # by the ``site`` module for interactive sessions and is not available when
    # the interpreter is started with ``-S``.
    sys.exit(1)


def run(command):
    """Execute *command* and return its stdout as stripped text.

    Args:
        command: Argument list handed to ``subprocess.check_output``.

    Returns:
        The decoded standard output with leading and trailing whitespace
        removed.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status.
    """
    raw_output = subprocess.check_output(command)
    text = raw_output.decode()
    return text.strip()


def get_version_parts():
    """Collect the version components for HEAD by querying git.

    Returns:
        A dict with the keys:

        timestamp
            Output of ``git show -s --format=%ct HEAD`` (a string).
        closest_tag
            Tag reported by ``git describe``, or ``"0.0.0"`` when the
            describe output contains no tag.
        distance
            Number of commits since ``closest_tag`` (int); when there is no
            tag, the number of entries printed by ``git rev-list HEAD``.
        short_hash
            Abbreviated commit hash.

    Raises:
        subprocess.CalledProcessError: If any git command fails.
        SystemExit: Via ``panic`` when the describe output looks tagged but
            does not have the expected ``TAG-NUM-gHEX`` shape.
    """
    git_describe = run(["git", "describe", "--tags", "--always", "--long"])
    commit_timestamp = run(["git", "show", "-s", "--format=%ct", "HEAD"])

    parts = {"timestamp": commit_timestamp}

    if "-" in git_describe:
        # TAG-NUM-gHEX
        match = re.match(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
        if not match:
            # panic() requires a message; calling it bare would raise TypeError
            # and hide the real problem from the user.
            panic(f"Unable to parse git describe output: {git_describe!r}")

        # tag
        parts["closest_tag"] = match.group(1)

        # distance: number of commits since tag
        parts["distance"] = int(match.group(2))

        # commit: short hex revision ID
        parts["short_hash"] = match.group(3)

    else:
        # HEX: no tags, so --always made describe print only the short hash
        parts["short_hash"] = git_describe

        parts["closest_tag"] = "0.0.0"

        git_rev_list = run(["git", "rev-list", "HEAD", "--left-right"])
        # total number of commits
        parts["distance"] = len(git_rev_list.split())

    return parts


if __name__ == "__main__":
    # Command-line entry point: print the version string for the mode given as
    # the first argument ('parts', 'local' or 'dev').
    valid_modes = ("parts", "local", "dev")

    if len(sys.argv) < 2:
        panic("Missing mode. Expected one of 'parts', 'local', 'dev'.")

    mode = sys.argv[1].lower()

    # Reject unknown modes before shelling out to git, so a usage error is
    # reported as such even where the git commands themselves would fail
    # (for example outside a repository).
    if mode not in valid_modes:
        panic(f"Invalid mode {mode!r}. Expected one of 'parts', 'local', 'dev'.")

    parts = get_version_parts()

    if mode == "parts":
        import json

        print(json.dumps(parts))

    elif mode == "local":
        print("{closest_tag}+{distance}.{short_hash}".format(**parts))

    else:
        # mode == "dev": bump the last component of the closest tag
        closest_tag = parts["closest_tag"]
        tag_parts = closest_tag.split(".")
        try:
            tag_parts[-1] = str(int(tag_parts[-1]) + 1)
        except ValueError:
            # e.g. a tag such as "1.0.0rc1" whose final component is not a
            # plain integer; exit cleanly instead of dumping a traceback.
            panic(
                f"Cannot bump tag {closest_tag!r}: "
                "its last dot-separated component is not an integer."
            )
        patch_tag = ".".join(tag_parts)

        print("{}.dev{timestamp}".format(patch_tag, **parts))
