#!/usr/bin/env scripts/uv-run-script
# -*- mode: python -*-
# /// script
# dependencies = [
#     "packaging>=23.1,<24",
# ]
# ///
"""
Validate that all project dependencies have well-defined version ranges.

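This script checks that every dependency declared in pyproject.toml (in
`project.dependencies` and every `project.optional-dependencies` group) has both
a lower and an upper version bound, so that a new release of a dependency
cannot silently introduce breaking changes.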

Acceptable formats:
- >=X.Y,<X.Z (explicit lower and upper bounds)
- ~=X.Y or ~=X.Y.Z (compatible release; needs at least X.Y and implies an upper bound)
- ==X.Y.Z (exact pin; bounds the version on both sides)

Unacceptable formats:
- >=X.Y (open-ended, no upper bound)
- <X.Y (only upper bound)
- ~=X (only a single version component; too open-ended and not valid under PEP 440)
- package (no version specified)
"""

import sys
import tomllib
from pathlib import Path
from typing import Dict, List, Tuple

from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import InvalidSpecifier, SpecifierSet


def load_pyproject() -> Dict:
    """Load and parse pyproject.toml from the current directory."""
    pyproject_path = Path("pyproject.toml")
    if not pyproject_path.exists():
        print(f"Error: {pyproject_path} not found")
        sys.exit(1)

    with open(pyproject_path, "rb") as f:
        return tomllib.load(f)


def has_lower_bound(spec_set: SpecifierSet) -> bool:
    """Return True if any clause in *spec_set* constrains the version from below.

    The operators ``>=``, ``>``, ``~=`` and ``==`` each establish a minimum
    version; any other operator (``<``, ``<=``, ``!=``, ...) does not.
    """
    return any(clause.operator in {">=", ">", "~=", "=="} for clause in spec_set)


def has_upper_bound(spec_set: SpecifierSet) -> bool:
    """Return True if any clause in *spec_set* caps the version from above.

    The operators ``<``, ``<=``, ``~=`` and ``==`` each establish a maximum
    version; any other operator (``>=``, ``>``, ``!=``, ...) does not.
    """
    capping_operators = frozenset(("<", "<=", "~=", "=="))
    return any(s.operator in capping_operators for s in spec_set)


def is_well_defined(spec_string: str) -> Tuple[bool, str]:
    """
    Check if a version specifier is well-defined.

    A specifier is well-defined when it is non-empty, parses as a PEP 440
    specifier set, has both a lower and an upper bound, and every compatible
    release (``~=``) clause names at least an ``x.y`` version.

    Args:
        spec_string: Raw version specifier such as ``">=1.2,<2"`` or
            ``"~=1.4.2"``. May be empty or whitespace when the dependency
            declares no version constraint.

    Returns:
        Tuple of (is_valid, reason_if_invalid). ``reason_if_invalid`` is the
        empty string when the specifier is valid.
    """
    # Handle empty string (no version specified)
    if not spec_string.strip():
        return False, "no version specified"

    try:
        spec_set = SpecifierSet(spec_string)
    except InvalidSpecifier as e:
        # Only malformed specifiers are a user error worth reporting; any
        # other exception would be a bug in this script and should surface.
        return False, f"invalid specifier: {e}"

    # Check for lower bound
    if not has_lower_bound(spec_set):
        return False, "missing lower bound (no >=, >, ~=, or ==)"

    # Check for upper bound
    if not has_upper_bound(spec_set):
        return False, "missing upper bound (no <, <=, ~=, or ==)"

    # If using compatible release (~=), ensure it has at least x.y format.
    # PEP 440 already requires two release components for ~=, so SpecifierSet
    # should reject "~=X" above; this is kept as a safety net.
    for spec in spec_set:
        if spec.operator == "~=":
            version_parts = spec.version.split(".")
            if len(version_parts) < 2:
                return False, f"compatible release ~= requires at least x.y format (got ~={spec.version})"

    return True, ""


def check_dependencies(data: Dict) -> List[str]:
    """
    Collect version-range violations from every project dependency list.

    Inspects ``project.dependencies`` and each group under
    ``project.optional-dependencies``; missing tables or keys are skipped.

    Returns:
        Violation messages from all inspected sections, in declaration order
        (empty when nothing is wrong or nothing was declared).
    """
    if "project" not in data:
        return []

    project = data["project"]
    found: List[str] = []

    if "dependencies" in project:
        found += check_dependency_section(project["dependencies"], "project.dependencies")

    if "optional-dependencies" in project:
        for group, members in project["optional-dependencies"].items():
            found += check_dependency_section(
                members,
                f"project.optional-dependencies[{group}]",
            )

    return found


def check_dependency_section(deps: List[str], section_name: str) -> List[str]:
    """
    Check a single dependency section for violations.

    Each entry is parsed as a PEP 508 requirement string with
    ``packaging.requirements.Requirement``. This correctly handles:

    - extras (``pkg[extra]>=1``)
    - parenthesised specifiers (``pkg (>=1,<2)``)
    - environment markers (``pkg>=1; python_version<"3.12"``)
    - direct references (``pkg @ https://...``)

    The earlier hand-rolled character scan mis-parsed several of these forms.

    Args:
        deps: Requirement strings from one pyproject.toml dependency list.
        section_name: Label used as the prefix of every violation message.

    Returns:
        Human-readable violation messages; empty if every entry passes.
    """
    violations = []

    for dep_line in deps:
        try:
            requirement = Requirement(dep_line)
        except InvalidRequirement as e:
            # Unparsable entries are reported rather than crashing the check.
            violations.append(f"{section_name}: '{dep_line}' is not a valid requirement ({e})")
            continue

        # str() of an empty SpecifierSet is "", which covers both "no version
        # given" and direct URL references; is_well_defined reports those as
        # "no version specified".
        version_spec = str(requirement.specifier)
        is_valid, reason = is_well_defined(version_spec)

        if not is_valid:
            violations.append(
                f"{section_name}: '{requirement.name}' has {reason} (specifier: '{version_spec}')"
            )

    return violations


def main() -> int:
    """Validate pyproject.toml in the working directory and report the result.

    Returns:
        Process exit status: 0 when every dependency has both bounds, 1 when
        at least one violation was found (each is printed on its own line).
    """
    problems = check_dependencies(load_pyproject())

    if not problems:
        print("✅ All dependencies have well-defined version ranges")
        return 0

    print("❌ Dependency validation failed:\n")
    for problem in problems:
        print(f"  {problem}")
    print("\nAll dependencies must have both lower and upper bounds.")
    print("Acceptable formats: >=X.Y,<X.Z or ~=X.Y.Z (compatible release)")
    return 1


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
