#!/usr/bin/env python3

"""Run the Firedrake smoke tests."""

import argparse
import importlib.util
import logging
import os
import pathlib
import shlex
import subprocess
import sys


# Smoke tests, keyed by the number of processes each group is run on.
# Comments above individual entries label the feature being exercised.
TESTS: dict[int, tuple[str, ...]] = {
    # serial tests
    1: (
        "tests/firedrake/regression/test_stokes_mini.py::test_stokes_mini",
        # spatialindex
        "tests/firedrake/regression/test_locate_cell.py",
        # supermesh
        "tests/firedrake/supermesh/test_assemble_mixed_mass_matrix.py::test_assemble_mixed_mass_matrix[2-CG-CG-0-0]",
        # fieldsplit
        "tests/firedrake/regression/test_matrix_free.py::test_fieldsplitting[parameters3-cofunc_rhs-variational]",
        # near nullspace
        "tests/firedrake/regression/test_nullspace.py::test_near_nullspace",
    ),
    # parallel tests on 2 processes
    2: (
        # HDF5/checkpointing
        "tests/firedrake/output/test_io_function.py::test_io_function_base[cell_family_degree2]",
    ),
    # parallel tests on 3 processes
    3: (
        "tests/firedrake/regression/test_dg_advection.py::test_dg_advection_icosahedral_sphere[nprocs=3]",
        # vertex-only mesh
        "tests/firedrake/regression/test_interpolate_cross_mesh.py::test_interpolate_cross_mesh[extrudedcube-nprocs=3-interpolate_function-0]",
    ),
}


# Log to the terminal at INFO level and above, printing only the message text
# (no level or logger-name prefix).
logging.basicConfig(format="%(message)s", level=logging.INFO)


def main() -> None:
    """Run every group of smoke tests in ``TESTS``, logging progress.

    Groups are run in the order they appear in ``TESTS``. Any exception
    raised by ``run_tests`` propagates, so later groups are not run after
    a failure.
    """
    mpiexec = parse_args().mpiexec

    for nprocs, group in TESTS.items():
        banner = (
            "    Running serial smoke tests"
            if nprocs == 1
            else f"    Running parallel smoke tests (nprocs={nprocs})"
        )
        logging.info(banner)
        run_tests(group, nprocs, mpiexec)
        logging.info("    Tests passed")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the smoke-test runner.

    Returns
    -------
    argparse.Namespace
        Namespace whose ``mpiexec`` attribute holds the MPI launcher
        command (``"mpiexec -n"`` unless ``--mpiexec`` is given).
    """
    mpiexec_help = (
        "Command used to launch MPI. The command must end with the flag "
        "taking the number of processors, e.g. '-n' for mpiexec."
    )

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--mpiexec", default="mpiexec -n", type=str, help=mpiexec_help
    )
    return cli.parse_args()


def run_tests(tests: tuple[str, ...], nprocs: int, mpiexec: str) -> None:
    """Run one group of smoke tests with pytest in a subprocess.

    Parameters
    ----------
    tests
        Pytest test identifiers, resolved relative to the ``_check``
        directory of the installed ``firedrake`` package.
    nprocs
        Number of processes to run on. Selects the ``parallel[nprocs]``
        pytest marker and, when greater than one, is passed to the MPI
        launcher.
    mpiexec
        MPI launcher command ending with the flag that takes the number
        of processes, e.g. ``"mpiexec -n"``. It is split using shell-like
        rules, so quoted arguments containing spaces are kept intact.

    Raises
    ------
    ModuleNotFoundError
        If the ``firedrake`` package cannot be located.
    subprocess.CalledProcessError
        If the test run exits with a non-zero status.
    """
    # Find the path to firedrake._check. Don't actually import Firedrake here
    # because that will initialise MPI and can prevent us calling 'mpiexec'
    # below. This only causes issues on some systems.
    spec = importlib.util.find_spec("firedrake")
    if spec is None or spec.origin is None:
        raise ModuleNotFoundError(
            "Cannot locate the 'firedrake' package; is Firedrake installed?",
            name="firedrake",
        )
    test_dir = pathlib.Path(spec.origin).parent / "_check"

    # Do verbose checking if running on CI
    verbosity = "--verbose" if "FIREDRAKE_CI" in os.environ else "--quiet"

    # Build the command as an argument list rather than splitting a formatted
    # string, so that an interpreter path containing spaces stays one argument.
    cmd = [
        sys.executable, "-m", "pytest",
        verbosity,
        # Always disable the cache provider because we don't want to generate
        # any cache files in $VIRTUAL_ENV/lib/.../firedrake/_check
        "-p", "no:cacheprovider",
        "-m", f"parallel[{nprocs}]",
    ]
    # Only use mpiexec for parallel tests because 'mpiexec -n 1' can hang
    if nprocs > 1:
        cmd = [*shlex.split(mpiexec), str(nprocs), *cmd]
    cmd.extend(tests)

    subprocess.run(cmd, cwd=test_dir, check=True)


# Only run the smoke tests when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
