AUTHOR
FIREWHEEL Team
DONE
DESCRIPTION

Pull a file or directory from a VM using the VM resource handler. This does not require that
the VM is running a SSH server. Also, unlike :ref:`helper_scp`, there is no need
to use the :ref:`control_network_mc` model component since the VM resource handler has access to
the VM through a serial port.

All files get placed at the location specified on the command line.

**Usage:**  ``firewheel pull file [-h] <filename> <vm hostname> <destination>``

Arguments
+++++++++

Named Arguments
^^^^^^^^^^^^^^^

.. option:: -h, --help

    Show help message and exit.

.. option:: --attempts

    :default: 24

    Number of 5 second attempts to try before giving up

Positional Arguments
^^^^^^^^^^^^^^^^^^^^

.. option:: <filename>

    The name of the file or directory on the VM to extract

.. option:: <vm hostname>

    The hostname of the VM within the experiment from which to pull the files

.. option:: <destination>

    Local destination path for the files that were extracted from the VM

Example
+++++++

``firewheel pull file /tmp/test.txt host.root.net /tmp/myfile.txt``

``firewheel pull file /tmp/test host.root.net /tmp/mydir``

DONE
RUN LocalPython ON control
#!/usr/bin/env python

import sys
import shlex
import pickle
import argparse
import subprocess
from time import sleep
from pathlib import Path, PureWindowsPath

from rich.console import Console
from rich.progress import Progress, TextColumn, SpinnerColumn, TimeElapsedColumn

from firewheel.config import config
from firewheel.lib.minimega.api import minimegaAPI
from firewheel.vm_resource_manager.schedule_db import ScheduleDb
from firewheel.vm_resource_manager.schedule_entry import ScheduleEntry


def handle_schedule_entry(name, file_name):
    """Add a file transfer entry to the schedule for a specified VM.

    This function retrieves the schedule for a given virtual machine (VM)
    from the database, creates a new schedule entry for transferring a
    specified file, and updates the schedule in the database.

    Args:
        name (str): The name of the VM for which the schedule is being updated.
        file_name (str): The name of the file to be transferred.
    """
    con = Console()
    schedule_db = ScheduleDb()
    pickled_schedule = schedule_db.get(name)
    if not pickled_schedule:
        con.print(f"[b red]Unable to get schedule for VM: [cyan]{name}")
        sys.exit(1)
    schedule = pickle.loads(pickled_schedule)

    # ScheduleEntry will have been loaded prior to this point automatically.
    se = ScheduleEntry(-100000)
    se.add_file_transfer(file_name, interval=None)
    schedule.append(se)
    schedule_db.put(name, pickle.dumps(schedule), None)


def call_scp(hostname, src, dest, max_attempts=24):
    """Transfer a file from a remote host to a local destination using SCP.

    This function checks if the specified file is ready for transfer on the
    remote host and then uses SCP to copy the file to the local destination.
    If the file is not ready, it will retry up to a specified number of
    attempts.

    Args:
        hostname (str): The hostname or IP address of the remote host.
        src (Path): The source file path on the remote host.
        dest (Path): The destination path on the local machine.
        max_attempts (int, optional): The maximum number of attempts to check
            if the file is ready. Defaults to 24.
    """
    con = Console()
    count = 1
    num_sec = 5
    lockfile = src.parent / f"{src.name}-lock"
    with Progress(
        TextColumn(
            "[yellow]Checking if the file is ready..."
        ),
        SpinnerColumn(spinner_name="line"),
        TimeElapsedColumn(),
        console=con,
    ) as progress:
        progress.add_task(description="check file")
        while True:
            ex = subprocess.run(  # noqa: S602
                shlex.split(
                    "ssh -o LogLevel=error -o UserKnownHostsFile=/dev/null "
                    f'-o StrictHostKeyChecking=no {hostname} "if [ -e {src} ] && '
                    f'[ ! -e {lockfile} ]; then echo True; fi"'
                ),
                capture_output=True,
                check=True,
            )
            out = ex.stdout.decode("utf-8").strip()
            if out == "True":
                break
            if count >= max_attempts:
                console.print(str(
                    f"[b red]Unable to get file: [cyan]{src}[/cyan] from host "
                    f"[cyan]{hostname}[/cyan]. If the file is very large, you may need "
                    "to increase the [magenta]--attempts[/magenta]. "
                    f"Each attempt waits for [cyan]{num_sec}[/cyan] seconds."
                ))
                sys.exit(1)
            count += 1
            sleep(num_sec)
    subprocess.run(  # noqa: S602
        "scp -r -o LogLevel=error -o UserKnownHostsFile=/dev/null "
        f"-o StrictHostKeyChecking=no {hostname}:{src} {dest}",
        shell=True,
        check=True,
    )
    subprocess.run(  # noqa: S602
        "ssh -o LogLevel=error -o UserKnownHostsFile=/dev/null "
        f'-o StrictHostKeyChecking=no {hostname} "rm -rf {src}"',
        shell=True,
        check=True,
    )
    sys.exit(0)


if __name__ == "__main__":

    # Set the arguments
    parser = argparse.ArgumentParser(
        description="Extract files from VMs within a FIREWHEEL experiment",
        prog="firewheel pull file",
    )

    parser.add_argument(
        "--attempts",
        type=int,
        default=24,
        help="Number of 5 second attempts to try before giving up",
    )

    parser.add_argument(
        "filename",
        help="The name of the file or directory on the VM to extract",
    )

    parser.add_argument(
        "vm_name",
        help="The hostname of the VM within the experiment from which to pull the files",
    )

    parser.add_argument(
        "destination",
        help="Local destination path for the files that were extracted from the VM",
    )

    args = parser.parse_args()

    filename = args.filename
    vm_name = args.vm_name
    destination = args.destination

    local_dest = Path(destination)
    if not local_dest.exists():
        local_dest.parent.mkdir(parents=True, exist_ok=True)

    handle_schedule_entry(vm_name, filename)

    mm_api = minimegaAPI()
    mm_dict = mm_api.mm_vms()
    host = mm_dict[vm_name]["hostname"]

    console = Console()
    path = None
    # Check if the VM is a Windows VM based on image name
    if "windows" in mm_dict[vm_name]["image"]:
        if not PureWindowsPath(filename.strip("/")).drive:
            console.print(
                f"[b red]Path from Windows VM [cyan]{vm_name}[/cyan] does not have a drive. "
                "Assuming [magenta]C:[/magenta]"
            )
            path = (
                Path(config["logging"]["root_dir"])
                / "transfers"
                / vm_name
                / PureWindowsPath("C:", filename).as_posix()
            )

    if path is None:
        path = (
            Path(config["logging"]["root_dir"])
            / "transfers"
            / vm_name
            / PureWindowsPath(filename.strip("/")).as_posix()
        )

    call_scp(host, path, local_dest, max_attempts=args.attempts)
DONE
