#!/usr/bin/env python3
"""boneio-migrate — privileged system migration helper.

This script runs as root (via NOPASSWD sudo) and applies declarative
migration plans sent by the ``boneio`` application via stdin (JSON).

It validates asset SHA-256 hashes (whenever the plan supplies one) before
writing any file and logs all actions to ``/var/log/boneio-migrate.log``.

SECURITY:
- Never reads passwords or secrets.
- Only accepts actions from the whitelist defined in ``ALLOWED_ACTIONS``.
- Validates asset SHA-256 before writing.
- Runs only as root (enforced at startup).
- Validates sudoers fragments with ``visudo -cf`` before installing.
"""

from __future__ import annotations

import hashlib
import json
import logging
import logging.handlers
import os
import pwd
import subprocess
import sys
import tempfile
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

# ---------------------------------------------------------------------------
# Setup
# ---------------------------------------------------------------------------

LOG_FILE = "/var/log/boneio-migrate.log"
APPLIED_DIR = Path("/var/lib/boneio/migrations.d")

# Handlers are assembled before basicConfig() so that failing to open the log
# file does not always abort the import with a traceback.
_LOG_HANDLERS = [logging.StreamHandler(sys.stderr)]
try:
    _LOG_HANDLERS.append(
        logging.handlers.RotatingFileHandler(
            LOG_FILE, maxBytes=1_048_576, backupCount=1, encoding="utf-8"
        )
    )
except OSError:
    # As root the audit log is mandatory: re-raise rather than run unlogged.
    # Without root (where opening /var/log usually fails) fall back to stderr
    # so _assert_root() can print its explanatory error instead of a traceback.
    if os.geteuid() == 0:
        raise

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=_LOG_HANDLERS,
)
_LOGGER = logging.getLogger("boneio-migrate")

# Security whitelist of action types accepted from the plan. A frozenset so it
# cannot be extended at runtime.
ALLOWED_ACTIONS = frozenset({
    "install_file",
    "remove_file",
    "systemctl_daemon_reload",
    "systemctl_enable",
    "systemctl_disable",
    "systemctl_restart",
    "systemctl_reload",
    "append_line_if_missing",
})


# ---------------------------------------------------------------------------
# Guards
# ---------------------------------------------------------------------------

def _assert_root() -> None:
    """Exit the process with status 1 unless the effective UID is 0 (root)."""
    if os.geteuid() == 0:
        return
    _LOGGER.error("boneio-migrate must be run as root via sudo.")
    sys.exit(1)


# ---------------------------------------------------------------------------
# SHA-256 helpers
# ---------------------------------------------------------------------------

def _sha256_of_bytes(data: bytes) -> str:
    """Hash *data* with SHA-256 and return the lower-case hex digest (64 chars)."""
    hasher = hashlib.sha256()
    hasher.update(data)
    return hasher.hexdigest()


def _render_template(content: str, template_vars: dict[str, str]) -> str:
    """Fill ``${KEY}`` / ``$KEY`` placeholders in *content* from *template_vars*.

    Uses ``string.Template.safe_substitute``, so placeholders without a
    matching key are left as-is instead of raising. When *template_vars* is
    empty (or otherwise falsy), *content* is returned untouched.
    """
    from string import Template

    if template_vars:
        return Template(content).safe_substitute(template_vars)
    return content


# ---------------------------------------------------------------------------
# Action handlers
# ---------------------------------------------------------------------------

def handle_install_file(action: dict[str, Any], assets_base: str) -> None:
    """Install a file from ``assets_base/src`` to ``dst`` atomically.

    Steps:
    1. Read the asset.
    2. Verify its SHA-256 when ``expected_sha256`` is given. The hash covers
       the un-rendered source.
    3. Apply ``${KEY}`` template substitutions.
    4. Return early if ``dst`` already holds identical bytes.
    5. Otherwise write and fsync the content to a temporary file in ``dst``'s
       directory, give it its ownership and mode, and optionally check it with
       ``validate_cmd``.
    6. Rename it over ``dst`` with :func:`os.replace`.

    ``on_change`` is dispatched only after an actual write.

    Args:
        action: Action dict. Keys used:
            - ``src``: path relative to ``assets_base``.
            - ``dst``: destination path.
            - ``mode`` (optional): int, or octal string such as ``"0644"``;
              default ``0o644``.
            - ``owner`` / ``group`` (optional): names; default ``"root"``.
            - ``template_vars`` (optional).
            - ``validate_cmd`` (optional): command string or argv list; the
              temporary file path is appended as the last argument.
            - ``expected_sha256`` (optional).
            - ``on_change`` (optional): a nested action.
        assets_base: Absolute path to the assets directory.

    Raises:
        FileNotFoundError: If the asset does not exist.
        ValueError: On an invalid mode, SHA-256 mismatch, or failed validation.
    """
    import grp
    import shlex

    src_rel = action["src"]
    dst = action["dst"]
    raw_mode = action.get("mode", 0o644)
    # JSON has no octal literals. A string mode such as "0644" is meant as
    # octal, whereas int("0644") would give 0o1204 (sticky bit, wrong perms).
    mode = int(raw_mode, 8) if isinstance(raw_mode, str) else int(raw_mode)
    owner = action.get("owner", "root")
    group = action.get("group", "root")
    template_vars = action.get("template_vars", {})
    validate_cmd = action.get("validate_cmd")
    expected_sha256 = action.get("expected_sha256")

    src_path = Path(assets_base) / src_rel
    if not src_path.exists():
        raise FileNotFoundError(f"Asset not found: {src_path}")
    raw_content = src_path.read_bytes()

    # Verify integrity before doing anything else with the asset. The
    # manifest hash covers the un-rendered source file.
    if expected_sha256:
        actual = _sha256_of_bytes(raw_content)
        if actual != str(expected_sha256).lower():
            raise ValueError(
                f"SHA-256 mismatch for {src_rel}: "
                f"expected={expected_sha256} actual={actual}"
            )

    if template_vars:
        rendered = _render_template(raw_content.decode("utf-8"), template_vars)
        content = rendered.encode("utf-8")
    else:
        content = raw_content

    # Idempotency: identical content means nothing to do. Mode and owner are
    # not re-applied in that case.
    dst_path = Path(dst)
    if dst_path.exists() and dst_path.read_bytes() == content:
        _LOGGER.info("SKIP (unchanged): %s", dst)
        return

    dst_path.parent.mkdir(parents=True, exist_ok=True)

    # The temp file lives next to dst so os.replace() is a same-filesystem
    # (atomic) rename. Everything after its creation is inside the try so the
    # file is removed again on any failure, including a failed write.
    tmp = tempfile.NamedTemporaryFile(
        dir=dst_path.parent, delete=False, suffix=".tmp"
    )
    tmp_path = tmp.name
    try:
        with tmp:
            tmp.write(content)
            tmp.flush()
            # Force the data to disk before the rename so a power loss cannot
            # leave dst empty or truncated.
            os.fsync(tmp.fileno())

        # Ownership is set before the mode because chown() clears
        # setuid/setgid bits on Linux, even when called by root. Both are
        # applied before the rename so dst never appears with the wrong
        # owner. Ownership problems remain best-effort (warning only).
        try:
            uid = pwd.getpwnam(owner).pw_uid
            gid = grp.getgrnam(group).gr_gid
            os.chown(tmp_path, uid, gid)
        except (KeyError, OSError) as exc:
            _LOGGER.warning(
                "Could not set owner %s:%s on %s: %s", owner, group, dst, exc
            )
        os.chmod(tmp_path, mode)

        # NOTE(security): validate_cmd comes from the plan and is executed as
        # root; only trusted plans should ever reach this helper.
        if validate_cmd:
            argv = (
                shlex.split(validate_cmd)
                if isinstance(validate_cmd, str)
                else [str(arg) for arg in validate_cmd]
            )
            result = subprocess.run(
                [*argv, tmp_path],
                capture_output=True,
                text=True,
                timeout=10,
            )
            if result.returncode != 0:
                raise ValueError(
                    f"Validation failed for {dst}: {result.stderr.strip()}"
                )

        os.replace(tmp_path, dst)
    except BaseException:
        # Do not leave a stray temp file behind (write, validation or rename
        # failed, or the run was interrupted).
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise

    _LOGGER.info("WRITE: %s (mode=%o owner=%s:%s)", dst, mode, owner, group)

    # Follow-up action (e.g. a reload) only when the file actually changed.
    on_change = action.get("on_change")
    if on_change:
        _LOGGER.info("Executing on_change action for %s", dst)
        dispatch_action(on_change, assets_base)


def handle_remove_file(action: dict[str, Any], assets_base: str) -> None:
    """Remove ``action["path"]`` if it is present.

    The unlink is attempted directly (EAFP) instead of checking
    ``Path.exists()`` first. ``exists()`` follows symlinks and reports
    ``False`` for a dangling symlink, which would leave such links in place.
    A path that is missing, or whose parent is not a directory, is treated as
    already removed.

    Args:
        action: Action dict with ``path``.
        assets_base: Unused; present for handler signature consistency.

    Raises:
        OSError: For other unlink failures (e.g. the path is a directory or
            permission is denied).
    """
    path = Path(action["path"])
    try:
        path.unlink()
    except (FileNotFoundError, NotADirectoryError):
        _LOGGER.info("SKIP (not found): %s", path)
        return
    _LOGGER.info("REMOVED: %s", path)


def handle_systemctl_daemon_reload(action: dict[str, Any], assets_base: str) -> None:
    """Run ``systemctl daemon-reload``; both arguments are ignored.

    Raises ``subprocess.CalledProcessError`` on a non-zero exit and
    ``subprocess.TimeoutExpired`` after 30 seconds.
    """
    argv = ["systemctl", "daemon-reload"]
    _LOGGER.info("systemctl daemon-reload")
    subprocess.run(argv, timeout=30, check=True)


def handle_systemctl_enable(action: dict[str, Any], assets_base: str) -> None:
    """Enable a systemd unit with ``systemctl enable``.

    Args:
        action: Action dict with ``unit`` (unit name).
        assets_base: Unused; present for handler signature consistency.

    Raises:
        subprocess.CalledProcessError: If systemctl exits non-zero.
        subprocess.TimeoutExpired: If systemctl runs longer than 15 seconds.
    """
    unit = action["unit"]
    _LOGGER.info("systemctl enable %s", unit)
    # "--" ends option parsing, so a caller-supplied unit starting with "-"
    # is treated as a unit name rather than a systemctl option.
    subprocess.run(["systemctl", "enable", "--", unit], check=True, timeout=15)


def handle_systemctl_disable(action: dict[str, Any], assets_base: str) -> None:
    """Disable a systemd unit with ``systemctl disable`` (best effort).

    A non-zero exit status does not fail the migration (``check=False``). It
    is logged as a warning so the failure is not silently lost.

    Args:
        action: Action dict with ``unit`` (unit name).
        assets_base: Unused; present for handler signature consistency.

    Raises:
        subprocess.TimeoutExpired: If systemctl runs longer than 15 seconds.
    """
    unit = action["unit"]
    _LOGGER.info("systemctl disable %s", unit)
    # "--" ends option parsing, so a caller-supplied unit starting with "-"
    # is treated as a unit name rather than a systemctl option.
    result = subprocess.run(
        ["systemctl", "disable", "--", unit], check=False, timeout=15
    )
    if result.returncode != 0:
        _LOGGER.warning(
            "systemctl disable %s exited with status %d (ignored)",
            unit,
            result.returncode,
        )


def handle_systemctl_restart(action: dict[str, Any], assets_base: str) -> None:
    """Restart a systemd unit with ``systemctl restart``.

    Args:
        action: Action dict with ``unit`` (unit name).
        assets_base: Unused; present for handler signature consistency.

    Raises:
        subprocess.CalledProcessError: If systemctl exits non-zero.
        subprocess.TimeoutExpired: If systemctl runs longer than 30 seconds.
    """
    unit = action["unit"]
    _LOGGER.info("systemctl restart %s", unit)
    # "--" ends option parsing, so a caller-supplied unit starting with "-"
    # is treated as a unit name rather than a systemctl option.
    subprocess.run(["systemctl", "restart", "--", unit], check=True, timeout=30)


def handle_systemctl_reload(action: dict[str, Any], assets_base: str) -> None:
    """Reload a systemd unit with ``systemctl reload``.

    Args:
        action: Action dict with ``unit`` (unit name).
        assets_base: Unused; present for handler signature consistency.

    Raises:
        subprocess.CalledProcessError: If systemctl exits non-zero.
        subprocess.TimeoutExpired: If systemctl runs longer than 30 seconds.
    """
    unit = action["unit"]
    _LOGGER.info("systemctl reload %s", unit)
    # "--" ends option parsing, so a caller-supplied unit starting with "-"
    # is treated as a unit name rather than a systemctl option.
    subprocess.run(["systemctl", "reload", "--", unit], check=True, timeout=30)


def handle_append_line_if_missing(action: dict[str, Any], assets_base: str) -> None:
    """Append ``line`` to the file at ``path`` unless an identical line exists.

    The file is created if it does not exist (its parent directory must
    exist). A newline separator is inserted only when the existing content
    does not already end with one. This avoids adding a blank line on every
    append, or at the top of a newly created file.

    Args:
        action: Action dict with ``path`` and ``line``.
        assets_base: Unused; present for handler signature consistency.
    """
    path = Path(action["path"])
    line = action["line"]

    existing = ""
    if path.exists():
        existing = path.read_text(encoding="utf-8")
        if line in existing.splitlines():
            _LOGGER.info("SKIP (line present): %s in %s", line, path)
            return

    separator = "\n" if existing and not existing.endswith("\n") else ""
    with path.open("a", encoding="utf-8") as f:
        f.write(f"{separator}{line}\n")
    _LOGGER.info("APPENDED: %s to %s", line, path)


# ---------------------------------------------------------------------------
# Dispatcher
# ---------------------------------------------------------------------------

# Registry of action name -> handler. Every handler takes (action, assets_base);
# names must also appear in ALLOWED_ACTIONS to be dispatchable.
ACTION_HANDLERS = dict(
    install_file=handle_install_file,
    remove_file=handle_remove_file,
    systemctl_daemon_reload=handle_systemctl_daemon_reload,
    systemctl_enable=handle_systemctl_enable,
    systemctl_disable=handle_systemctl_disable,
    systemctl_restart=handle_systemctl_restart,
    systemctl_reload=handle_systemctl_reload,
    append_line_if_missing=handle_append_line_if_missing,
)


def dispatch_action(action: dict[str, Any], assets_base: str) -> None:
    """Route one plan action to its registered handler.

    The ``"action"`` key must name a type that is in ``ALLOWED_ACTIONS`` and
    has an entry in ``ACTION_HANDLERS``.

    Args:
        action: Action dict (must have an "action" key).
        assets_base: Absolute path to the assets directory, passed through
            to the handler.

    Raises:
        ValueError: If the type is not whitelisted, or is whitelisted but has
            no handler.
    """
    kind = action.get("action")
    if kind in ALLOWED_ACTIONS:
        handler = ACTION_HANDLERS.get(kind)
        if handler is not None:
            handler(action, assets_base)
            return
        raise ValueError(f"No handler for action: {kind!r}")
    raise ValueError(f"Disallowed action type: {kind!r}")


# ---------------------------------------------------------------------------
# Applied flag
# ---------------------------------------------------------------------------

def write_applied_flag(version: str) -> None:
    """Record that migration ``version`` has been applied.

    Creates ``APPLIED_DIR`` if needed and writes ``<version>.applied`` there.
    The file contains a single ``applied_at=<UTC ISO-8601 timestamp>`` line.

    Args:
        version: Migration version string (e.g. "1.3.0"). It comes from the
            plan read on stdin, so it must be a plain file-name component.

    Raises:
        ValueError: If ``version`` contains a path separator. Such a value
            would otherwise let the flag be written outside ``APPLIED_DIR``.
    """
    name = f"{version}.applied"
    if os.sep in name or (os.altsep and os.altsep in name):
        raise ValueError(f"Invalid migration version for flag file: {version!r}")

    APPLIED_DIR.mkdir(parents=True, exist_ok=True)
    flag_path = APPLIED_DIR / name
    flag_path.write_text(
        f"applied_at={datetime.now(timezone.utc).isoformat()}\n",
        encoding="utf-8",
    )
    _LOGGER.info("FLAG: %s", flag_path)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------

def main() -> int:
    """Entry point: read a JSON plan from stdin, validate it, apply it, flag it.

    The plan is a JSON object with ``version`` (non-empty), ``assets_base``
    (an existing directory) and ``actions`` (a list of action objects).
    Actions run in order; the first failure stops the run. The applied flag
    is written only after every action succeeded.

    Returns:
        0 on success, 1 on any failure (malformed plan, failed action, or
        failure to write the applied flag).
    """
    _assert_root()

    _LOGGER.info("boneio-migrate started (pid=%d)", os.getpid())

    try:
        payload_raw = sys.stdin.read()
        payload = json.loads(payload_raw)
    except Exception as exc:
        _LOGGER.error("Failed to parse JSON plan from stdin: %s", exc)
        return 1

    # Structural checks up front: otherwise a malformed plan surfaces as an
    # uncaught AttributeError/TypeError traceback instead of a logged error.
    if not isinstance(payload, dict):
        _LOGGER.error("JSON plan must be an object, got %s.", type(payload).__name__)
        return 1

    version = payload.get("version")
    actions = payload.get("actions", [])
    assets_base = payload.get("assets_base")

    if not version:
        _LOGGER.error("Plan missing 'version' field.")
        return 1
    if (
        not assets_base
        or not isinstance(assets_base, str)
        or not Path(assets_base).is_dir()
    ):
        _LOGGER.error("Plan 'assets_base' is missing or not a directory: %s", assets_base)
        return 1
    if not isinstance(actions, list):
        _LOGGER.error("Plan 'actions' must be a list, got %s.", type(actions).__name__)
        return 1

    total = len(actions)
    _LOGGER.info("Applying migration version=%s (%d actions)", version, total)

    for i, action in enumerate(actions, start=1):
        if not isinstance(action, dict):
            _LOGGER.error("[%d/%d] Action is not an object: %r", i, total, action)
            return 1
        action_type = action.get("action", "?")
        _LOGGER.info("[%d/%d] %s", i, total, action_type)
        try:
            dispatch_action(action, assets_base)
        except Exception as exc:
            _LOGGER.error("[%d/%d] Action %s failed: %s", i, total, action_type, exc)
            return 1

    try:
        write_applied_flag(version)
    except (OSError, ValueError) as exc:
        _LOGGER.error("Failed to write applied flag for %s: %s", version, exc)
        return 1

    _LOGGER.info("Migration %s completed successfully.", version)
    return 0


if __name__ == "__main__":
    sys.exit(main())
