#!/usr/bin/env python

import argparse
import curses
import json
import os
import re
import shlex
import sys
from itertools import chain
from subprocess import run, PIPE, CalledProcessError
from textwrap import wrap

__version__ = "0.3"
__scriptname__ = os.path.basename(__file__)

# Shell used as the default --open command.
SHELL = os.environ.get("SHELL", "/bin/sh")

# Title drawn on the top border (only visible with --border).
DEFAULT_TITLE = " blkmenu "

# Columns shown by default. A leading "<" or ">" sets left/right alignment.
DEFAULT_COLUMNS = [
    "name",
    ">size",
    "type",
    "fstype",
    "partlabel",
    "label",
    "rm",
    "ro",
    "mountpoint",
]

# Default key bindings: key name as returned by curses `getkey()` -> action.
DEFAULT_ACTIONS = {
    "q": "quit",
    "j": "movedown",
    "KEY_DOWN": "movedown",
    "k": "moveup",
    "KEY_UP": "moveup",
    "m": "mount",
    "u": "unmount",
    "l": "lock",
    "L": "unlock",
    "e": "eject",
    "o": "open",
    "i": "info",
    "r": "refresh",
    "a": "toggle_filter",
    "?": "help",
}

# Human readable description of every action, shown on the help screen.
ACTION_DESCRIPTIONS = {
    "quit": "Exit the program",
    "movedown": "Move down a line",
    "moveup": "Move up a line",
    "mount": "Mount the selected device",
    "unmount": "Unmount the selected device",
    "lock": "Lock the selected device",
    "unlock": "Unlock the selected device",
    "eject": "Eject the selected device",
    "open": f"Open the selected device mountpoint with the given --open command (defaults to '{SHELL}')",
    "info": "Display all device properties",
    "refresh": "Refresh the device tree.",
    # Filters are configured with -f (filter) and -p (prune); -a/-A are the
    # key-binding options.
    "toggle_filter": "Toggle active filters (set with -f or -p options)",
    "help": "Display this help screen",
}

# Every action name the menu knows how to perform.
VALID_ACTIONS = set(DEFAULT_ACTIONS.values())


class DeviceError(Exception):
    """Raised when an operation on a device fails."""


class CmdError(Exception):
    """Raised when an external command is missing or exits with an error."""


class ActionError(Exception):
    """Raised when a menu action cannot be completed."""


def fatal(message):
    """Terminate the program, reporting `message` as the exit status/message.

    Args:
        message (str): Text handed to SystemExit.

    Raises:
        SystemExit: Always.
    """
    sys.exit(message)


def vifm(server_name, *commands):
    """Sends commands to a running vifm instance.

    Args:
        server_name (str): Name of the target vifm instance.
        commands (list): Commands to execute, each passed with its own `-c`.

    Raises:
        CmdError: `vifm` was not found or exited with a non-zero status.
    """
    # vifm requires --remote to come after --server-name
    argv = ["vifm", "--server-name", server_name, "--remote"]
    for command in commands:
        argv += ["-c", command]

    try:
        run(argv, check=True, encoding="utf-8")
    except FileNotFoundError:
        raise CmdError("vifm: command not found")
    except CalledProcessError as e:
        raise CmdError(f"vifm: command failed with code {e.returncode}: {' '.join(argv)}")


class Device:
    """Device object.

    Attributes:
        parent (Device): The parent device or None for a root node.
        children (list): Device children.
        info (dict): Device properties.
        tree_padding (str): Tree decorations that represent the device hierarchy.
            Makes sense only in the context of a list returned by
            DeviceTree.list().
        root (bool): Whether or not the device is root. The root device holds no info.
    """

    CMD = "udisksctl"
    VIEWER = "less"

    def __init__(self, info, parent=None, root=False):
        """Creates a device from a dict of properties.

        Args:
            info (dict): Device properties. The "children" key, if present, is
                removed from this dict and stored in `children`.
            parent (Device): The parent device, if any.
            root (bool): Whether or not this is the root node.
        """
        self.parent = parent
        self.children = info.pop("children", [])
        self.info = info
        self.tree_padding = ""
        self.root = root

    def mount(self, options=None, fstype=None):
        """Mounts device.

        Args:
            options (str): Mount options.
            fstype (str): Filesystem type to use. As per udisksctl docs, if not
                specified, autodetected filesystem type will be used.

        Returns:
            The `udisksctl` standard output.
        """
        fstype = [] if fstype is None else ["-f", fstype]
        options = [] if options is None else ["-o", options]
        return self._do("mount", "-b", self.info["path"], *options, *fstype)

    def unmount(self):
        """Unmounts device."""
        return self._do("unmount", "-b", self.info["path"])

    def eject(self):
        """Ejects device (runs `udisksctl power-off`)."""
        return self._do("power-off", "-b", self.info["path"])

    def lock(self):
        """Locks device."""
        return self._do("lock", "-b", self.info["path"])

    def unlock(self):
        """Unlocks device."""
        return self._do("unlock", "-b", self.info["path"])

    def view_details(self):
        """Views all device properties with the VIEWER command (`less`).

        Raises:
            CmdError: The VIEWER command was not found or returned an error.
        """
        # Right-align property names in a column two characters wider than the
        # longest one; `default` keeps an empty property dict from crashing.
        pad = max(map(len, self.info), default=0) + 2
        text = "\n".join(
            f"{key.upper():>{pad}}  {value}" for key, value in self.info.items()
        )
        try:
            run([self.VIEWER], input=text, check=True, encoding="utf-8")
        except FileNotFoundError as e:
            raise CmdError(f"{self.VIEWER}: command not found") from e
        except CalledProcessError as e:
            raise CmdError(f"{self.VIEWER}: exited with code {e.returncode}") from e

    def open(self, cmd):
        """Opens the device mountpoint with the given command.

        The command is run through the shell with the mountpoint as working
        directory. Any `{}` placeholder in the command is replaced by the
        shell-quoted mountpoint.

        Args:
            cmd (str): Command used to open the device mountpoint.

        Raises:
            CmdError: The given command was not found or returned an error.
            DeviceError: The device was not mounted.
        """
        mountpoint = self.info["mountpoint"]
        if not mountpoint:
            raise DeviceError(f"Device '{self.info['name']}' not mounted.")
        # The mountpoint usually derives from a filesystem label, i.e. from
        # untrusted media. repr() is not shell quoting (a path containing a
        # single quote is rendered in double quotes, where `$` and backticks
        # are still expanded), so quote it properly.
        cmd = cmd.format(shlex.quote(mountpoint))
        try:
            run(cmd, cwd=mountpoint, shell=True, check=True, encoding="utf-8")
        except CalledProcessError as e:
            raise CmdError(e) from e

    def _do(self, *args):
        """Runs udisksctl command.

        Args:
            args (list): Udisksctl arguments.

        Returns:
            The `udisksctl` process standard output.

        Raises:
            CmdError: The `udisksctl` command was not found.
            DeviceError: The `udisksctl` command returned an error.
        """
        cmd = [self.CMD, *args]
        try:
            proc = run(cmd, check=True, stdout=PIPE, stderr=PIPE, encoding="utf8")
        except FileNotFoundError as e:
            raise CmdError(f"{self.CMD}: command not found") from e
        except CalledProcessError as e:
            raise DeviceError(e.stderr.strip()) from e
        return proc.stdout.strip()

    def __str__(self):
        return f"<{self.__class__.__name__} {self.info.get('path')}>"

    def __repr__(self):
        return f"{self.__class__.__name__}(path={self.info.get('path')!r})"


class DeviceTree:
    """Hierarchy of block devices.

    The tree is built from the json output of lsblk (lsblk -JO).

    Attributes:
        _filters (list): Python expressions used to filter devices. When any of
            them evaluates to True for a device, that device is dropped from the
            tree and its children are attached to its parent instead.
        _prunes (list): Python expressions used to prune devices. When any of
            them evaluates to True for a device, the device is dropped together
            with all of its descendants.
        _tree (Device): The root node of the device tree.
    """

    def __init__(self, filters=None, prunes=None, filter=True):
        self._filters = filters if filters else []
        self._prunes = prunes if prunes else []
        self._tree = self._get_device_tree(filter=filter)

    def refresh(self, filter=True):
        """Re-reads the devices and rebuilds the tree.

        Args:
            filter (bool): Whether or not to apply the filter/prune rules.
        """
        self._tree = self._get_device_tree(filter=filter)

    def list(self):
        """Flattens the device tree, depth first.

        As a side effect every returned device gets its `tree_padding` set.

        Returns:
            A list of devices.

        Note:
            Do not sort the returned list: the tree padding is only meaningful
            in the order produced here.
        """
        flat = []

        def _walk(device, padding, last_child):
            if device.parent.root:
                # direct children of the root get no decoration at all
                own, below = "", padding
            else:
                own = padding + ("└─ " if last_child else "├─ ")
                below = padding + ("   " if last_child else "│  ")

            device.tree_padding = own
            flat.append(device)

            last_index = len(device.children) - 1
            for index, child in enumerate(device.children):
                _walk(child, below, index == last_index)

        for top_level in self._tree.children:
            _walk(top_level, "", 0)
        return flat

    def _match(self, device, rules):
        """Tells whether at least one rule holds for the device.

        Each rule is evaluated with the device properties as local names and
        with only `match` and `search` (from `re`) as globals. A rule that
        raises is treated as not matching.

        Args:
            device (Device): The device the rules are checked against.
            rules (list): List of python expressions.

        Returns:
            True if any of the rules evaluates to a true value, False otherwise.
        """
        # NOTE: rules are eval()'d; they come from the command line (-f / -p).
        namespace = {"__builtins__": None, "match": re.match, "search": re.search}
        properties = device.info

        def _holds(rule):
            try:
                return bool(eval(rule, namespace, properties))
            except Exception:
                return False

        return any(_holds(rule) for rule in rules)

    def _lsblk(self):
        """Runs `lsblk -JO` and decodes its json output.

        Exits the program if `lsblk` is missing or fails.

        Returns:
            A raw dict object representing the device tree.
        """
        command = ["lsblk", "-JO"]
        try:
            proc = run(command, stdout=PIPE, check=True, encoding="utf8")
        except FileNotFoundError:
            fatal("lsblk: command not found")
        except CalledProcessError as e:
            fatal(f"lsblk: command failed with exit code: {e.returncode}")
        return json.loads(proc.stdout)

    def _get_device_tree(self, filter=True):
        """Turns the `lsblk` output into a tree of Device objects.

        Args:
            filter (bool): Whether or not to apply the filter/prune rules.

        Returns:
            The root node of the device tree.
        """

        def _excluded(node, rules):
            # the root node is never matched against the rules
            return not node.root and filter and self._match(node, rules)

        def _build(node):
            if _excluded(node, self._prunes):
                return None

            kept = []
            for raw_child in node.children:
                result = _build(Device(raw_child, parent=node))
                if not result:
                    continue
                # a filtered-out child hands back the list of its own children
                if isinstance(result, list):
                    kept.extend(result)
                else:
                    kept.append(result)

            if _excluded(node, self._filters):
                # skip this node: its children move one level up
                for child in kept:
                    child.parent = child.parent.parent
                return kept

            node.children = kept
            return node

        raw = self._lsblk()
        root = Device({"children": raw["blockdevices"]}, root=True)
        return _build(root)


class Formatter:
    """Tabularize data.

    Attributes:
        data (iterable): Entries to format.
        columns (list): Which entry fields to output. A column name may be
            prefixed with "<" or ">" to align it left or right (left is the
            default).
        col_getter (callable): Function used to retrieve entry fields. Called as
            `col_getter(entry, column_name)` with the alignment prefix already
            removed. Defaults to `entry[column_name]`.
        header_transform (callable): Function called to transform each column
            name into its header. Defaults to making the name uppercase.
        show_header (bool): Whether or not to return the header.
        separator (str): Column separator. Used when `stretch == False`.
        width (int): Max table width. Used when `stretch == True`.
        stretch (bool): Whether or not to spread the columns over the whole
            `width`.
    """

    def __init__(
        self,
        data=None,
        columns=None,
        col_getter=None,
        header_transform=None,
        show_header=True,
        separator=" ",
        width=0,
        stretch=False,
    ):
        self.data = data or []
        self.columns = columns or []
        self.col_getter = col_getter or (lambda entry, col: entry[col])
        self.show_header = show_header
        self.header_transform = header_transform or (lambda col: col.upper())
        self.width = width
        self.stretch = stretch
        self.separator = separator

    @staticmethod
    def _split_column(col):
        """Splits a column spec into its alignment and its bare name.

        Returns:
            A tuple (align, name) where align is "<" or ">".
        """
        align = col[0] if col.startswith(("<", ">")) else "<"
        return align, col.lstrip("<>")

    def format(self):
        """Formats the data as a table.

        Returns:
            A tuple containing the header string (empty when the header is not
            shown) and a list of all formatted entries: (header, lines)
        """
        # Parse every column spec exactly once so that the width computation
        # and the rendering always agree on the column name.
        specs = [self._split_column(col) for col in self.columns]

        # Fetch each cell once; this also allows `data` to be any iterable.
        rows = [
            [self.col_getter(entry, name) for _, name in specs] for entry in self.data
        ]

        # find the max width of each column
        widths = []
        for index, (_, name) in enumerate(specs):
            header_width = len(name) if self.show_header else 0
            widths.append(max([len(row[index]) for row in rows] + [header_width]))

        if self.stretch:
            # arrange columns to span the whole screen width
            avail_width = max(self.width - sum(widths), 0)
            sep = " " * max(avail_width // max(len(specs) - 1, 1), 1)
        else:
            sep = self.separator

        header_cells = []
        if self.show_header:
            header_cells = [
                f"{self.header_transform(name):{align}{width}}"
                for (align, name), width in zip(specs, widths)
            ]
        header = sep.join(header_cells)

        lines = [
            sep.join(
                f"{value:{align}{width}}"
                for value, (align, _), width in zip(row, specs, widths)
            )
            for row in rows
        ]

        return header, lines


class Menu:
    """Device menu.

    Attributes:
        stdscr (curses.window): Window object representing the whole screen.
        args (argparse.Namespace): Command line arguments.
        actions (dict): Map of keys/actions.
        filter (bool): Whether or not to filter devices.
        device_tree (DeviceTree): Device tree.
        device_list (list): Device tree as a flat list.
        selected (int): Selected device as an index of `device_list`.
        error (str): Current error message.
        message (str): Feedback message.
        refresh (bool): Whether or not to refresh the device tree after a key press.
    """

    def __init__(self, stdscr, actions, args):

        self.stdscr = stdscr
        self.args = args
        self.actions = actions

        self.filter = bool(args.filters or args.prunes)
        # Pass the flag computed above, not the `filter` builtin.
        self.device_tree = DeviceTree(
            filters=args.filters, prunes=args.prunes, filter=self.filter
        )
        self.device_list = []
        self.selected = -1

        self.init_state()

        curses.curs_set(False)
        curses.use_default_colors()
        curses.init_pair(1, curses.COLOR_RED, -1)

    def init_state(self):
        """Resets the per-keypress state (messages and refresh request)."""
        self.error = ""
        self.message = ""
        self.refresh = False

    def do(self, device, action):
        """Performs an action on the given device.

        Actions that need a device ("mount", "unmount", "eject", "lock",
        "unlock", "info", "open") are silently ignored when `device` is None.

        Args:
            device (Device): The target device, or None if there is none.
            action (str): The action to perform.

        Raises:
            ActionError: Action failed.
            SystemExit: User decided to quit or open the device mountpoint with vifm.
        """
        if action == "quit":
            sys.exit()

        if action == "help":
            self.view_help()
            return

        if action in ("moveup", "movedown"):
            if self.device_list:
                step = -1 if action == "moveup" else 1
                self.selected = (self.selected + step) % len(self.device_list)
            return

        if action == "refresh":
            self.refresh = True
            return

        if action == "toggle_filter":
            self.filter = not self.filter
            self.refresh = True
            return

        # everything below operates on the selected device
        if not device:
            return

        try:
            if action == "mount":
                self.message = device.mount(options=self.args.mount_opts)
                self.refresh = True
            elif action in ("unmount", "eject", "lock"):
                self.message = getattr(device, action)()
                self.refresh = True
            elif action == "unlock":
                # udisksctl may prompt for a passphrase on the terminal
                self.message = self._run_outside_curses(device.unlock)
                self.refresh = True
            elif action == "info":
                self._run_outside_curses(device.view_details)
            elif action == "open":
                self._open(device)
        except (CmdError, DeviceError) as e:
            raise ActionError(e) from e

    def _open(self, device):
        """Opens the device mountpoint, either in vifm or with the --open command."""
        if not self.args.vifm:
            self._run_outside_curses(device.open, self.args.open_cmd)
            return

        mountpoint = device.info["mountpoint"]
        if not mountpoint:
            raise ActionError(f"Device '{device.info['name']}' not mounted.")
        vifm(self.args.vifm, f"cd {mountpoint!r}")
        sys.exit()

    def _run_outside_curses(self, func, *args):
        """Calls `func(*args)` with the terminal temporarily released by curses."""
        curses.savetty()
        curses.endwin()
        try:
            return func(*args)
        finally:
            curses.resetty()

    def _addstr(self, y, x, text, attr=curses.A_NORMAL):
        """Writes `text` at (y, x), dropping whatever does not fit the window.

        `addstr` raises `curses.error` when asked to draw outside the window,
        which would otherwise crash the menu on small terminals or with long
        device lists.
        """
        height, width = self.stdscr.getmaxyx()
        if not (0 <= y < height and 0 <= x < width):
            return
        try:
            self.stdscr.addstr(y, x, text[: width - x], attr)
        except curses.error:
            # Writing into the bottom-right cell raises even though the text
            # has been drawn; there is nothing else to do about it.
            pass

    def start(self):
        """Starts the main loop."""
        self.device_list = self.device_tree.list()
        self.selected = 0

        while True:

            self.draw()

            key = self.stdscr.getkey()

            self.init_state()

            if key not in self.actions:
                continue

            action = self.actions[key]
            if action not in VALID_ACTIONS:
                self.error = f"Unknown action: {action}"
                continue

            try:
                device = self.device_list[self.selected]
            except IndexError:
                device = None

            try:
                self.do(device, action)
            except ActionError as e:
                self.error = self.format_error(e, device)

            if self.refresh:
                self.device_tree.refresh(filter=self.filter)
                self.device_list = self.device_tree.list()

    def format_error(self, error, device):
        """Transforms an ActionError into a legible error message.

        Backticks become single quotes, the "GDBus...: " prefix is removed and
        "Object /..." references are replaced by the device path.

        Args:
            error (ActionError): The ActionError to format.
            device (Device): The device for which the action failed, or None.

        Returns:
            A nicely formatted error message.
        """
        msg = str(error)
        msg = msg.replace("`", "'")
        msg = re.sub("GDBus.*?: ", "", msg, flags=re.IGNORECASE)
        if device is not None:
            path = device.info["path"]
            # A callable replacement keeps backslashes in the path literal.
            msg = re.sub(r"Object /\S+", lambda _: path, msg, flags=re.IGNORECASE)
        return msg

    def view_help(self):
        """Displays the help and waits for keypress to exit."""
        self.stdscr.erase()
        y, x = self.draw_border(0, 0, title=" help ")

        entries = {}
        for key, action in self.actions.items():
            entries.setdefault(
                action,
                {
                    "action": action,
                    "keys": [],
                    # keys may be bound to actions that do not exist
                    "description": ACTION_DESCRIPTIONS.get(action, "(unknown action)"),
                },
            )
            entries[action]["keys"].append(key)

        header, lines = Formatter(
            data=entries.values(),
            col_getter=lambda e, c: ", ".join(e[c]) if c == "keys" else e[c],
            columns=["action", "keys", "description"],
            separator="    ",
            stretch=False,
        ).format()

        # remove space taken up by the border
        width = max(self.stdscr.getmaxyx()[1] - (x * 2), 0)

        if header:
            self._addstr(y, x, header[:width], curses.A_BOLD)
            y += 1

        for line in lines:
            self._addstr(y, x, line[:width])
            y += 1

        self.stdscr.refresh()
        self.stdscr.getkey()

    def draw_border(self, y, x, title=""):
        """Draws the window border.

        Nothing is drawn unless the --border option is set.

        Args:
            x, y (int): Window coordinates that indicate where to start drawing.
            title (str): The window title to display on the border at the
                top-left corner.

        Returns:
            The updated coordinates, i.e. where the content should start.
        """
        if self.args.border:
            self.stdscr.border()
            if title:
                self._addstr(y, x + 3, title, curses.A_BOLD)
            y, x = y + 1, x + 2
        return y, x

    def _col_getter(self, dev, prop):
        """Extracts the given property from the device as a string.

        None becomes "", booleans become "0"/"1" and the tree column is
        prefixed with the device tree padding. Exits the program if the
        property does not exist.
        """
        prop = prop.lstrip("<>")
        try:
            value = dev.info[prop]
        except KeyError:
            fatal(f"{__scriptname__}: invalid column: {prop}")
        if value is None:
            return ""
        if value is False or value is True:
            return str(int(value))
        if self.args.tree is not None and prop == self.args.tree:
            return dev.tree_padding + str(value)
        return str(value)

    def draw(self):
        """Draws the menu."""

        self.stdscr.erase()
        y, x = self.draw_border(0, 0, title=self.args.title)

        # remove space taken up by the border
        width = max(self.stdscr.getmaxyx()[1] - (x * 2), 0)

        header, lines = Formatter(
            data=self.device_list,
            columns=self.args.columns,
            col_getter=self._col_getter,
            show_header=self.args.show_header,
            separator=self.args.sep,
            stretch=self.args.stretch,
            width=width,
        ).format()

        if not self.device_list:
            self._addstr(y, x, "No device found.")
            y += 1

        if header and self.device_list:
            self._addstr(y, x, header[:width], curses.A_BOLD)
            y += 1

        # make sure a device is always selected when toggling filters
        self.selected = max(min(self.selected, len(self.device_list) - 1), 0)

        for i, line in enumerate(lines):
            attr = curses.A_REVERSE if i == self.selected else curses.A_NORMAL
            self._addstr(y, x, line[:width], attr)
            y += 1

        y += 1

        color = curses.color_pair(1) if self.error else curses.A_NORMAL
        message = self.error if self.error else self.message
        # wrap() rejects a non-positive width
        if message and width > 0:
            for line in wrap(message, width=width):
                self._addstr(y, x, line, color)
                y += 1

        self.stdscr.refresh()


def parse_args(argv=None):
    """Parses and normalizes the command line arguments.

    Args:
        argv (list): Arguments to parse. Defaults to `sys.argv[1:]`.

    Returns:
        An argparse.Namespace where:
          - `filters`, `prunes`, `actions` and `actions_override` are flat lists;
          - `columns` is a list of column specs;
          - `tree` is the name of the tree column (without alignment prefix),
            or None when --flat was given.
    """
    parser = argparse.ArgumentParser(allow_abbrev=False, add_help=False)
    parser.add_argument("-h", "--help", action="help", help="Print help information.")
    parser.add_argument(
        "-V",
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
        help="Print version information.",
    )
    parser.add_argument("--border", action="store_true", help="Show menu border.")
    parser.add_argument(
        "--no-border",
        action="store_false",
        dest="border",
        help="Don't show menu border. This is the default behavior.",
    )
    parser.add_argument(
        "--title",
        default=DEFAULT_TITLE,
        help="Set menu title. Shown only when --border is used.",
    )
    parser.add_argument(
        "--no-title",
        action="store_const",
        const="",
        dest="title",
        help="Don't show menu title. Same as --title ''.",
    )
    parser.add_argument(
        "--columns", default=DEFAULT_COLUMNS, help="Show only the given columns."
    )
    parser.add_argument(
        "--header",
        action="store_true",
        default=True,
        dest="show_header",
        help="Show header. This is the default behavior.",
    )
    parser.add_argument(
        "--no-header", action="store_false", dest="show_header", help="Hide header."
    )
    # "" means "use the first column"; it is resolved once columns are known.
    parser.add_argument(
        "--tree",
        metavar="COLUMN_NAME",
        nargs="?",
        default="",
        const="",
        help="The column displayed as hierarchy tree. Defaults to the first column.",
    )
    parser.add_argument(
        "--flat",
        action="store_const",
        const=None,
        dest="tree",
        help="Don't display hierarchy relationships.",
    )
    parser.add_argument(
        "--stretch",
        action="store_true",
        default=True,
        help="Stretch columns to fill the whole window. This is the default behavior.",
    )
    parser.add_argument(
        "--no-stretch",
        action="store_false",
        dest="stretch",
        help="Don't stretch columns to fill the whole window.",
    )
    parser.add_argument(
        "--sep",
        metavar="SEPARATOR",
        default="  ",
        help="Columns separator when --no-stretch is given.",
    )
    parser.add_argument(
        "--open",
        metavar="CMD",
        default=SHELL,
        dest="open_cmd",
        help="Command used to open a device mountpoint.",
    )
    parser.add_argument(
        "--mount-opts",
        metavar="OPTS",
        default="nosuid,noexec,noatime",
        help="Default mount options.",
    )
    parser.add_argument(
        "--vifm",
        metavar="SERVER_NAME",
        nargs="?",
        const="vifm",
        help="Vifm server name. Bypass the --open option and open the device mountpoint directly in the vifm instance named SERVER_NAME.",
    )
    parser.add_argument(
        "-f",
        metavar="EXPR",
        dest="filters",
        nargs="+",
        action="append",
        default=[],
        help="Exclude devices from the menu. Treated as python expressions.",
    )
    parser.add_argument(
        "-p",
        metavar="EXPR",
        dest="prunes",
        nargs="+",
        action="append",
        default=[],
        help="Exclude devices and all their descendants from the menu. Treated as python expressions.",
    )
    parser.add_argument(
        "-a",
        metavar="KEY:ACTION",
        dest="actions",
        nargs="+",
        action="append",
        default=[],
        help="Bindings as key:action pairs.",
    )
    parser.add_argument(
        "-A",
        metavar="KEY:ACTION",
        dest="actions_override",
        nargs="+",
        action="append",
        default=[],
        help="Same as -a but clears other bindings that map to the same action.",
    )

    args = parser.parse_args(argv)

    # nargs="+" combined with action="append" yields a list of lists
    args.prunes = list(chain.from_iterable(args.prunes))
    args.filters = list(chain.from_iterable(args.filters))

    args.actions = list(chain.from_iterable(args.actions))
    args.actions_override = list(chain.from_iterable(args.actions_override))

    if isinstance(args.columns, str):
        # tolerate spaces around names and stray commas ("name, size,")
        names = [name.strip() for name in args.columns.split(",")]
        args.columns = [name for name in names if name] or list(DEFAULT_COLUMNS)
    else:
        # never hand out the module-level default list itself
        args.columns = list(args.columns)

    if args.tree == "":
        args.tree = args.columns[0].lstrip("<>")

    return args


def start_menu(stdscr, actions, args):
    """Builds the menu on the given screen and runs its main loop.

    Meant to be used as the callback of `curses.wrapper`.
    """
    Menu(stdscr, actions, args).start()


def _parse_binding_pair(pair):
    key, _, action = pair.rpartition(":")
    if not key or not action:
        fatal(f"{__scriptname__}: invalid 'key:action' pair: {pair}")
    return key, action


def main():
    """Entry point: builds the key bindings and starts the curses menu."""

    args = parse_args()

    # Work on a copy so the module-level defaults are never modified.
    actions = dict(DEFAULT_ACTIONS)

    # force an action to be mapped to only one key
    for pair in args.actions_override:
        key, action = _parse_binding_pair(pair)
        actions = {k: a for k, a in actions.items() if a != action}
        actions[key] = action

    for pair in args.actions:
        key, action = _parse_binding_pair(pair)
        actions[key] = action

    curses.wrapper(start_menu, actions, args)


# Start the menu only when the file is run as a script, not when imported.
if __name__ == "__main__":
    main()
