#!/usr/bin/env python3
# ---------------------------------------------------------------------------
# Command Line Interface for MTDA
# ---------------------------------------------------------------------------
#
# This software is a part of MTDA.
# Copyright (C) 2023 Siemens Digital Industries Software
#
# ---------------------------------------------------------------------------
# SPDX-License-Identifier: MIT
# ---------------------------------------------------------------------------

# System imports
import daemon
import getopt
import lockfile
import os
import os.path
import requests
import signal
import time
import sys
import zerorpc
import socket
import zeroconf

# Local imports
from mtda.main import MultiTenantDeviceAccess
from mtda.client import Client
from mtda.console.screen import ScreenOutput
import mtda.constants as CONSTS


class Application:

    def __init__(self):
        """Set up the default state of the command-line application."""
        # No agent yet: main() creates it once the command line is parsed
        self.agent = None
        self.exiting = False

        # Defaults for the remote to talk to and the active channel
        self.remote = "localhost"
        self.channel = "console"

        # Files used by daemonize()
        self.logfile = "/var/log/mtda.log"
        self.pidfile = "/var/run/mtda.pid"

        # Created last since the screen is handed this application object
        self.screen = ScreenOutput(self)

    def daemonize(self):
        """Run the MTDA server from within a daemon context.

        Both stdout and stderr of the daemon are sent to ``self.logfile``
        and ``self.pidfile`` is used as the lock file of the context.
        SIGTERM and SIGHUP are mapped to the 'terminate' action.

        Returns:
            The status returned by server().
        """
        # Open the log file once and share the stream between stdout and
        # stderr: two independent 'w+' handles would both truncate the file
        # and then overwrite one another's output (each has its own offset).
        with open(self.logfile, 'w+') as log:
            context = daemon.DaemonContext(
                working_directory=os.getcwd(),
                stdout=log,
                stderr=log,
                umask=0o002,
                pidfile=lockfile.FileLock(self.pidfile)
            )

            context.signal_map = {
                signal.SIGTERM: 'terminate',
                signal.SIGHUP:  'terminate',
            }

            with context:
                return self.server()

    def _ip(self):
        """Return the local IPv4 address selected for outbound traffic.

        A datagram socket is connected to 192.0.2.0 and the local address
        the system bound it to is read back with getsockname().

        Returns:
            The local address as a dotted-quad string.

        Raises:
            OSError: if the socket cannot be created or connected.
        """
        # The context manager closes the socket even when connect() fails
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            s.connect(('192.0.2.0', 0))
            return s.getsockname()[0]

    def server(self):
        """Start the agent, expose it over ZeroRPC and advertise it via mDNS.

        The call blocks in the RPC server loop. A KeyboardInterrupt stops
        the RPC server and the agent. The mDNS service is always
        unregistered and the Zeroconf instance closed on the way out.

        Returns:
            False if the agent failed to start, True once the RPC server
            loop has ended.
        """
        # Start our agent
        status = self.agent.start()
        if status is False:
            return False

        # Start our RPC server
        uri = "tcp://*:%d" % (self.agent.ctrlport)
        s = zerorpc.Server(self.agent, heartbeat=20)
        s.bind(uri)

        # Initialize ZeroConf
        interfaces = zeroconf.InterfaceChoice.All
        zc = zeroconf.Zeroconf(interfaces=interfaces)
        try:
            props = {
                'description': 'Multi-Tenant Device Access'
            }
            deviceType = CONSTS.MDNS.TYPE
            name = self.agent.name
            info = zeroconf.ServiceInfo(
                    type_=deviceType,
                    name='{}.{}'.format(name, deviceType),
                    addresses=[socket.inet_aton(self._ip())],
                    port=int(self.agent.ctrlport),
                    properties=props,
                    server='{}.local.'.format(name))

            zc.register_service(info)
            try:
                s.run()
            except KeyboardInterrupt:
                s.stop()
                self.agent.stop()
            finally:
                zc.unregister_service(info)
        finally:
            # Release the Zeroconf instance whether or not the service was
            # registered (it was previously never closed)
            zc.close()

        return True

    def client(self):
        """Return the agent used to serve requests.

        This is the object stored in ``self.agent``; main() sets it to
        either a ``Client`` or a ``MultiTenantDeviceAccess`` instance.
        """
        return self.agent

    def debug(self, level, msg):
        """Forward a debug message and its level to the agent's debug()."""
        self.client().debug(level, msg)

    def command_cmd(self, args):
        """Send a command to the device and print its response.

        Args:
            args: list of words making up the command.

        Returns:
            0 when the agent answered with a string (printed to stdout),
            1 otherwise so that the failure shows in the exit status.
        """
        result = self.client().command(args)
        if isinstance(result, str):
            print(result)
            return 0

        print("Device command '{0}' failed!".format(
            " ".join(args)), file=sys.stderr)
        return 1

    def console_clear(self, args):
        """Ask the agent to clear the console buffer (``args`` is unused)."""
        self.client().console_clear()

    def console_dump(self, args=None):
        data = self.client().console_dump()
        if data is not None:
            sys.stdout.write(data)
            sys.stdout.flush()

    def console_flush(self, args=None):
        data = self.client().console_flush()
        if data is not None:
            sys.stdout.write(data)
            sys.stdout.flush()

    def console_head(self, args):
        line = self.client().console_head()
        if line is not None:
            sys.stdout.write(line)
            sys.stdout.flush()

    def console_lines(self, args):
        lines = self.client().console_lines()
        sys.stdout.write("%d\n" % (lines))
        sys.stdout.flush()

    def console_interactive(self, args=None):
        """Run the interactive console until ``self.exiting`` is set.

        Keys read from the local console are sent to the active channel
        (console or monitor). When stdin is a terminal, target information
        is printed first and the key following the prefix key is handled
        by console_menukey() instead of being sent.
        """
        agent = self.agent
        server = self.client()

        # Print target information
        if sys.stdin.isatty():
            self.target_info()

        # Connect to the consoles
        for connect in (agent.console_remote, agent.monitor_remote):
            connect(self.remote, self.screen)

        agent.console_init()

        # The prefix key is only honoured on a terminal
        prefix_key = agent.console_prefix_key() if sys.stdin.isatty() else None

        # Input loop
        while self.exiting is False:
            key = agent.console_getkey()
            if prefix_key is not None and key == prefix_key:
                self.console_menukey(agent.console_getkey())
                continue
            if self.channel == 'console':
                server.console_send(key, True)
            else:
                server.monitor_send(key, True)

        print("\r\nThank you for using MTDA!\r\n\r\n")

    def console_menukey(self, c):
        """Handle the key typed after the prefix key in interactive mode.

        Recognized keys: 'a' lock the target, 'b' post the console buffer
        to pastebin, 'c' toggle screen capture, 'i' print target
        information, 'm' switch between console and monitor, 'p' toggle
        target power, 'q' quit, 'r' unlock the target, 's' swap the shared
        storage, 't' toggle timestamps, 'u' toggle USB port 1.
        Any other key is ignored.
        """
        agent = self.agent
        server = self.client()

        if c == 'a':
            if server.target_lock() is True:
                server.console_print("\r\n*** Target was acquired ***\r\n")
            return

        if c == 'b':
            self.console_pastebin()
            return

        if c == 'c':
            if self.screen.capture_enabled() is False:
                self.screen.print(b"\r\n*** Screen capture started... ***\r\n")
                self.screen.capture_start()
            else:
                self.screen.capture_stop()
                self.screen.print(b"\r\n*** Screen capture stopped ***\r\n")
            return

        if c == 'i':
            self.target_info()
            return

        if c == 'm':
            if self.channel == 'console':
                # Switch the alternate screen buffer
                print("\x1b[?1049h")  # same as tput smcup
                self.channel = 'monitor'
            else:
                # Return to the main screen buffer
                print("\x1b[?1049l")  # same as tput rmcup
                self.channel = 'console'
            agent.console_toggle()
            return

        if c == 'p':
            before = server.target_status()
            server.target_toggle()
            after = server.target_status()
            if before != after:
                server.console_print(
                    "\r\n*** Target is now %s ***\r\n" % (after))
            return

        if c == 'q':
            self.screen.capture_stop()
            self.exiting = True
            return

        if c == 'r':
            if server.target_unlock() is True:
                server.console_print("\r\n*** Target was released ***\r\n")
            return

        if c == 's':
            before, _, _ = server.storage_status()
            server.storage_swap()
            after, _, _ = server.storage_status()
            if after != before:
                server.console_print(
                    "\r\n*** Storage now connected to "
                    "%s ***\r\n" % (after))
            return

        if c == 't':
            server.toggle_timestamps()
            return

        if c == 'u':
            server.usb_toggle(1)

    def console_pastebin(self):
        """Post the content of the console buffer to the pastebin service.

        The API key and endpoint are obtained from the agent; a notice is
        printed on the console when either is missing. The outcome of the
        upload (response text or failure reason) is printed on the console.
        """
        client = self.agent
        server = self.client()
        api_key = client.pastebin_api_key()
        endpoint = client.pastebin_endpoint()
        if api_key is None or endpoint is None:
            server.console_print(
                    "\r\n*** key/endpoint for pastebin "
                    "are not configured! ***\r\n")
            return
        data = {
                'api_dev_key': api_key,
                'api_option': 'paste',
                'api_paste_code': server.console_dump(),
                'api_paste_format': 'python'
               }
        try:
            # Without a timeout an unresponsive endpoint would freeze the
            # interactive console forever
            r = requests.post(url=endpoint, data=data, timeout=30)
        except requests.exceptions.RequestException as e:
            # Report network failures instead of crashing the session
            server.console_print(
                "\r\n*** failed to post console buffer: %s ***\r\n" % (e))
            return
        server.console_print(
            "\r\n*** console buffer posted to %s ***\r\n" % (r.text))

    def console_prompt(self, args):
        prompt = None
        if len(args) > 0:
            prompt = args[0]
        data = self.client().console_prompt(prompt)
        if data is not None:
            sys.stdout.write(data)
            sys.stdout.flush()

    def console_raw(self, args=None):
        """Relay keys to the device console until ``self.exiting`` is set.

        The console is connected and initialized, the current content of
        the console buffer is dumped, then every key read is sent as-is.
        """
        agent = self.agent
        server = self.client()

        # Connect to the console
        agent.console_remote(self.remote, self.screen)

        # Prepare the console and show what was already buffered
        agent.console_init()
        self.console_dump()

        # Input loop
        while self.exiting is False:
            server.console_send(agent.console_getkey(), True)

    def console_run(self, args):
        data = self.client().console_run(args[0])
        if data is not None:
            sys.stdout.write(data)
            sys.stdout.flush()

    def console_send(self, args):
        """Send a string of characters to the device console.

        Args:
            args: list whose first item is the string to send.

        Returns:
            1 when the string argument is missing, 0 otherwise.
        """
        # Report a usage error rather than failing with an IndexError
        if not args:
            print("'console send' expects a string argument!",
                  file=sys.stderr)
            return 1

        self.client().console_send(args[0])
        return 0

    def console_tail(self, args):
        line = self.client().console_tail()
        if line is not None:
            sys.stdout.write(line)
            sys.stdout.flush()

    def console_wait(self, args):
        """Wait for a string to show up on the device console.

        Args:
            args: the string to wait for, optionally followed by a timeout
                (an integer, passed as-is to the agent).

        Returns:
            0 if the agent reported True, 1 otherwise (including when the
            arguments are missing or invalid).
        """
        # Report usage errors rather than failing with an exception
        if not args:
            print("'console wait' expects a string argument!",
                  file=sys.stderr)
            return 1

        timeout = None
        if len(args) == 2:
            try:
                timeout = int(args[1])
            except ValueError:
                print("invalid timeout '%s' for 'console wait'!" % (args[1]),
                      file=sys.stderr)
                return 1

        result = self.client().console_wait(args[0], timeout)
        return 0 if result is True else 1

    def console_help(self, args=None):
        print("The 'console' command accepts the following sub-commands:")
        print("   clear         Clear any data present in the console buffer")
        print("   dump          Dump content of the console buffer")
        print("   flush         Flush content of the console buffer")
        print("   head          Fetch and print the first line from the"
              " console buffer")
        print("   interactive   Open the device console for interactive use")
        print("   lines         Print number of lines present in the console"
              " buffer")
        print("   prompt        Configure or print the target shell prompt")
        print("   raw           Open the console for use from scripts")
        print("   run           Run the specified command via the device"
              " console")
        print("   send          Send characters to the device console")
        print("   tail          Fetch and print the last line from the console"
              " buffer")
        print("   wait          Wait for the specified string on the console")

    def console_cmd(self, args):
        """Dispatch a 'console' sub-command.

        Args:
            args: the sub-command followed by its own arguments; the
                sub-command is removed from the list before dispatching.

        Returns:
            The value returned by the sub-command handler, 1 when the
            sub-command is unknown, None when ``args`` is empty.
        """
        if len(args) > 0:
            cmd = args.pop(0)

            cmds = {
               'help': self.console_help,
               'clear': self.console_clear,
               'dump': self.console_dump,
               'flush': self.console_flush,
               'head': self.console_head,
               'interactive': self.console_interactive,
               'lines': self.console_lines,
               'prompt': self.console_prompt,
               'raw': self.console_raw,
               'run': self.console_run,
               'send': self.console_send,
               'tail': self.console_tail,
               'wait': self.console_wait
            }

            if cmd in cmds:
                return cmds[cmd](args)

            print("unknown console command '%s'!" % (cmd), file=sys.stderr)
            # Same status as the storage/target/usb dispatchers so that the
            # process does not exit with 0 on an unknown sub-command
            return 1

    def getenv_cmd(self, args=None):
        if len(args) == 0:
            print("'getenv' expects a key argument!", file=sys.stderr)
            return 1

        name = args[0]
        value = self.agent.env_get(name)
        if value is not None:
            print(str(value))
        return 0

    def help_cmd(self, args=None):
        status = 0
        if args is not None and len(args) > 0:
            cmd = args[0]
            args.pop(0)

            cmds = {
               'console': self.console_help,
               'storage': self.storage_help,
               'target': self.target_help,
               'usb': self.usb_help
            }

            if cmd in cmds:
                cmds[cmd](args)
            else:
                print("no help found for command '%s'!" % (cmd),
                      file=sys.stderr)
                status = 1
        else:
            print("usage: mtda [options] <command> [<args>]")
            print("")
            print("The most commonly used mtda commands are:")
            print("   command   Send a command (string) to the device")
            print("   console   Interact with the device console")
            print("   getenv    Get named variable from the environment")
            print("   monitor   Interact with the device monitor console"
                  " (if any)")
            print("   target    Power control the device")
            print("   setenv    Set named variable to specified value in the"
                  " environment")
            print("   storage   Interact with the shared storage device")
            print("   usb       Control USB devices attached to the device")
            print("")
        return status

    def keyboard_help(self, args=None):
        print("The 'keyboard' command accepts the following sub-commands:")
        print("   write         Write a string of characters via the keyboard")

    def keyboard_write(self, args):
        """Write a string of characters via the keyboard.

        Args:
            args: list whose first item is the string to write.

        Returns:
            1 when the string argument is missing, 0 otherwise.
        """
        # Report a usage error rather than failing with an IndexError
        if not args:
            print("'keyboard write' expects a string argument!",
                  file=sys.stderr)
            return 1

        self.client().keyboard_write(args[0])
        return 0

    def keyboard_cmd(self, args):
        """Dispatch a 'keyboard' sub-command.

        Args:
            args: the sub-command followed by its own arguments; the
                sub-command is removed from the list before dispatching.

        Returns:
            The value returned by the sub-command handler, 1 when the
            sub-command is unknown, None when ``args`` is empty.
        """
        if len(args) > 0:
            cmd = args.pop(0)

            cmds = {
               'help': self.keyboard_help,
               'write': self.keyboard_write
            }

            if cmd in cmds:
                # Propagate the handler status like the other dispatchers
                return cmds[cmd](args)

            print("unknown keyboard command '%s'!" % (cmd),
                  file=sys.stderr)
            return 1

    def monitor_cmd(self, args):
        """Dispatch a 'monitor' sub-command.

        Args:
            args: the sub-command followed by its own arguments; the
                sub-command is removed from the list before dispatching.

        Returns:
            The value returned by the sub-command handler, 1 when the
            sub-command is unknown, None when ``args`` is empty.
        """
        if len(args) > 0:
            cmd = args.pop(0)

            cmds = {
               'help': self.monitor_help,
               'send': self.monitor_send,
               'wait': self.monitor_wait
            }

            if cmd in cmds:
                return cmds[cmd](args)

            print("unknown monitor command '%s'!" % (cmd), file=sys.stderr)
            # Same status as the storage/target/usb dispatchers so that the
            # process does not exit with 0 on an unknown sub-command
            return 1

    def monitor_help(self, args=None):
        print("The 'monitor' command accepts the following sub-commands:")
        print("   send          Send characters to the device monitor")
        print("   wait          Wait for the specified string on the monitor")

    def monitor_send(self, args):
        """Send a string of characters to the device monitor.

        Args:
            args: list whose first item is the string to send.

        Returns:
            1 when the string argument is missing, 0 otherwise.
        """
        # Report a usage error rather than failing with an IndexError
        if not args:
            print("'monitor send' expects a string argument!",
                  file=sys.stderr)
            return 1

        self.client().monitor_send(args[0])
        return 0

    def monitor_wait(self, args):
        """Wait for a string to show up on the device monitor.

        Args:
            args: the string to wait for, optionally followed by a timeout
                (an integer, passed as-is to the agent).

        Returns:
            0 if the agent reported True, 1 otherwise (including when the
            arguments are missing or invalid).
        """
        # Report usage errors rather than failing with an exception
        if not args:
            print("'monitor wait' expects a string argument!",
                  file=sys.stderr)
            return 1

        timeout = None
        if len(args) == 2:
            try:
                timeout = int(args[1])
            except ValueError:
                print("invalid timeout '%s' for 'monitor wait'!" % (args[1]),
                      file=sys.stderr)
                return 1

        result = self.client().monitor_wait(args[0], timeout)
        return 0 if result is True else 1

    def setenv_cmd(self, args=None):
        """Set a variable in the agent's environment.

        Args:
            args: the name of the variable, optionally followed by its
                value (None is passed to the agent when absent).

        Returns:
            1 when the name is missing, 0 otherwise.
        """
        # 'not args' also covers the default (None), which len() rejects
        if not args:
            print("'setenv' expects a key argument!", file=sys.stderr)
            return 1

        name = args[0]
        value = args[1] if len(args) > 1 else None

        self.agent.env_set(name, value)
        return 0

    def storage_cmd(self, args):
        """Dispatch a 'storage' sub-command.

        The sub-command is removed from ``args`` and the remaining items
        are handed to its handler, whose result is returned. 1 is returned
        for an unknown sub-command and None when ``args`` is empty.
        """
        if len(args) == 0:
            return None

        cmd = args.pop(0)
        handlers = {
           'help': self.storage_help,
           'host': self.storage_host,
           'mount': self.storage_mount,
           'target': self.storage_target,
           'update': self.storage_update,
           'write': self.storage_write
        }

        handler = handlers.get(cmd)
        if handler is None:
            print("unknown 'storage' command '%s'!" % (cmd),
                  file=sys.stderr)
            return 1
        return handler(args)

    def storage_help(self, args=None):
        print("The 'storage' command accepts the following sub-commands:")
        print("   host      Attach the shared storage device to the host")
        print("   mount     Mount the shared storage device on the host")
        print("   target    Attach the shared storage device to the target")
        print("   update    Update the specified file on the shared storage"
              " device")
        print("   write     Write an image to the shared storage device")

    def storage_host(self, args=None):
        """Attach the shared storage device to the host.

        Returns 1 (after printing an error) if the agent reported False,
        0 otherwise.
        """
        if self.client().storage_to_host() is not False:
            return 0

        print("failed to connect the shared storage device to the host!",
              file=sys.stderr)
        return 1

    def storage_mount(self, args=None):
        """Attach the shared storage device to the host and mount it.

        Args:
            args: optional list whose first item is the partition to mount
                (None is passed to the agent when absent).

        Returns:
            0 on success, 1 if attaching to the host or mounting failed.
        """
        status = self.storage_host()
        if status != 0:
            return 1
        # 'args' defaults to None, which len() would reject
        part = args[0] if args else None
        status = self.agent.storage_mount(part)
        if status is False:
            print("'storage mount' failed!", file=sys.stderr)
            return 1
        return 0

    def storage_target(self, args):
        """Attach the shared storage device to the target.

        Returns 1 (after printing an error) if the agent reported False,
        0 otherwise.
        """
        if self.client().storage_to_target() is not False:
            return 0

        print("failed to connect the shared storage device to the target!",
              file=sys.stderr)
        return 1

    def _human_readable_size(self, size):
        """Format a size in bytes as KiB, MiB or GiB.

        Sizes below 1 MiB are shown as whole KiB, sizes below 1 GiB as
        whole MiB and anything else as GiB with two decimals.
        """
        one_mib = 1024 * 1024
        one_gib = 1024 * one_mib

        if size < one_mib:
            return "{:d} KiB".format(int(size/1024))
        if size < one_gib:
            return "{:d} MiB".format(int(size/1024/1024))
        return "{:.2f} GiB".format(size/1024/1024/1024)

    def _storage_write_cb(self, imgname, totalread, imgsize):
        totalwritten = self.agent.storage_bytes_written()
        progress = int((float(totalread) / float(imgsize)) * float(100))
        blocks = int(round((20 * progress) / 100))
        spaces = ' ' * (20 - blocks)
        blocks = '#' * blocks
        elapsed = time.monotonic() - self.start_time
        speed = round((totalwritten / 1024 / 1024) / elapsed, 2)
        totalread = self._human_readable_size(totalread)
        totalwritten = self._human_readable_size(totalwritten)
        sys.stdout.write("\r{0}: [{1}] {2}% ({3} read, "
                         "{4} written, {5} MiB/s) ".format(
                             imgname, str(blocks + spaces), progress,
                             totalread, totalwritten, speed))
        sys.stdout.flush()

    def storage_update(self, args=None):
        if len(args) == 0:
            print("'storage update' expects a file argument!", file=sys.stderr)
            return 1

        dest = args[0]
        src = args[1] if len(args) > 1 else None

        self.start_time = time.monotonic()
        status = self.agent.storage_update(dest, src, self._storage_write_cb)
        sys.stdout.write("\n")
        sys.stdout.flush()

        if status is False:
            print("'storage update' failed!", file=sys.stderr)
            return 1
        return 0

    def storage_write(self, args=None):
        if len(args) == 0:
            print("'storage write' expects a file argument!", file=sys.stderr)
            return 1

        self.start_time = time.monotonic()
        status = self.agent.storage_write_image(
            args[0], self._storage_write_cb)
        sys.stdout.write("\n")
        sys.stdout.flush()

        if status is False:
            print("'storage write' failed!", file=sys.stderr)
            return 1
        return 0

    def target_help(self, args=None):
        print("The 'target' command accepts the following sub-commands:")
        print("   on      Power on the device")
        print("   off     Power off the device")
        print("   reset   Power cycle the device")
        print("   toggle  Toggle target power")
        print("   uptime  Print target uptime")

    def target_uptime(self):
        """Return the target uptime as a human readable string.

        The uptime reported by the agent (in seconds) is broken down into
        days, hours, minutes and seconds; units with a zero count are
        left out, so the result may be an empty string.
        """
        # Format and duration (in seconds) of each unit, largest first
        units = (
            (" %d days", 24 * 60 * 60.0),
            (" %d hours", 60 * 60.0),
            (" %d minutes", 60.0),
        )

        text = ""
        remaining = self.client().target_uptime()
        for fmt, span in units:
            count = int(remaining / span)
            if count > 0:
                text = text + fmt % count
                remaining = remaining % span

        seconds = int(remaining)
        if seconds > 0:
            text = text + " %d seconds" % seconds
        return text.strip()

    def target_info(self, args=None):
        """Print a summary of the host, the remote agent and the target.

        The summary includes versions, prefix key, session, power and
        storage status, USB ports and the video stream URL (if any).
        Lines end with a carriage return ahead of the newline.

        Raises:
            zerorpc.RemoteError: if querying the agent version fails for a
                reason other than a remote NameError.
        """
        sys.stdout.write("\rFetching target information...\r")
        sys.stdout.flush()

        # Get general information
        client = self.client()
        locked = " (locked)" if client.target_locked() else ""
        remote = "Local" if self.remote is None else self.remote
        session = client.session()
        storage_status, writing, written = client.storage_status()
        writing = "WRITING" if writing is True else "IDLE"
        written = self._human_readable_size(written)
        tgt_status = client.target_status()
        uptime = ""
        if tgt_status == "ON":
            uptime = " (up %s)" % self.target_uptime()
        try:
            remote_version = client.agent_version()
        except zerorpc.RemoteError as e:
            # Presumably raised by older agents that do not implement
            # agent_version(); anything else is propagated unchanged
            if e.name != 'NameError':
                raise
            remote_version = "<=0.5"

        host = MultiTenantDeviceAccess()
        # Map the control character to its letter (e.g. 0x01 -> 'a')
        prefix_key = chr(ord(client.console_prefix_key()) + ord('a') - 1)

        # Print general information
        print("Host           : %s (%s)%30s\r" % (
              socket.gethostname(), host.version, ""))
        print("Remote         : %s (%s)%30s\r" % (
              remote, remote_version, ""))
        # Label fixed: it used to read "Prefix key:    :" (stray colon)
        print("Prefix key     : ctrl-%s\r" % (prefix_key))
        print("Session        : %s\r" % (session))
        print("Target         : %-6s%s%s\r" % (tgt_status, locked, uptime))
        print("Storage on     : %-6s%s\r" % (storage_status, locked))
        print("Storage writes : %s (%s)\r" % (written, writing))

        # Print status of the USB ports (numbered from 1)
        ports = client.usb_ports()
        for ndx in range(0, ports):
            status = client.usb_status(ndx+1)
            print("USB #%-2d        : %s\r" % (ndx+1, status))

        # Print video stream details
        url = client.video_url()
        if url is not None:
            print("Video stream   : %s\r" % (url))

    def target_off(self, args=None):
        """Power off the target; return 0 if the agent reported True."""
        if self.client().target_off() is True:
            return 0
        return 1

    def target_on(self, args=None):
        """Power on the target; return 0 if the agent reported True."""
        if self.client().target_on() is True:
            return 0
        return 1

    def target_reset(self, args=None):
        """Power cycle the target.

        A target that is not "OFF" is powered off first, followed by a
        5 seconds pause. Returns the status of the failed power off, or
        the status of the final power on.
        """
        if self.client().target_status() != "OFF":
            off_status = self.target_off()
            if off_status != 0:
                return off_status
            time.sleep(5)
        return self.target_on()

    def target_toggle(self, args=None):
        """Toggle target power; return 0 if the target status changed."""
        agent = self.client()
        before = agent.target_status()
        self.client().target_toggle()
        after = self.client().target_status()
        if after != before:
            return 0
        return 1

    def target_uptime_cmd(self, args=None):
        uptime = self.target_uptime()
        print(uptime)
        return 0

    def target_cmd(self, args):
        """Dispatch a 'target' sub-command.

        The sub-command is removed from ``args`` and the remaining items
        are handed to its handler, whose result is returned. 1 is returned
        for an unknown sub-command and None when ``args`` is empty.
        """
        if len(args) == 0:
            return None

        cmd = args.pop(0)
        handlers = {
           'help': self.target_help,
           'off': self.target_off,
           'on': self.target_on,
           'reset': self.target_reset,
           'toggle': self.target_toggle,
           'uptime': self.target_uptime_cmd
        }

        handler = handlers.get(cmd)
        if handler is None:
            print("unknown target command '%s'!" % (cmd), file=sys.stderr)
            return 1
        return handler(args)

    def usb_cmd(self, args):
        """Dispatch a 'usb' sub-command.

        The sub-command is removed from ``args`` and the remaining items
        are handed to its handler, whose result is returned. 1 is returned
        for an unknown sub-command and None when ``args`` is empty.
        """
        if len(args) == 0:
            return None

        cmd = args.pop(0)
        handlers = {
           'help': self.usb_help,
           'off': self.usb_off,
           'on': self.usb_on
        }

        handler = handlers.get(cmd)
        if handler is None:
            print("unknown usb command '%s'!" % (cmd), file=sys.stderr)
            return 1
        return handler(args)

    def usb_help(self, args=None):
        print("The 'usb' command accepts the following sub-commands:")
        print("   on      Power on the specified USB device")
        print("   off     Power off the specified USB device")

    def usb_off(self, args):
        """Power off the USB device of the given class.

        The class is taken (and removed) from the front of ``args``.
        Returns 0 if the agent's answer is truthy, 1 otherwise or when the
        class argument is missing.
        """
        if len(args) == 0:
            print("missing class argument to 'usb off' command!",
                  file=sys.stderr)
            return 1

        klass = args.pop(0)
        if self.client().usb_off_by_class(klass):
            return 0
        return 1

    def usb_on(self, args):
        """Power on the USB device of the given class.

        The class is taken (and removed) from the front of ``args``.
        Returns 0 if the agent's answer is truthy, 1 otherwise or when the
        class argument is missing.
        """
        if len(args) == 0:
            print("missing class argument to 'usb on' command!",
                  file=sys.stderr)
            return 1

        klass = args.pop(0)
        if self.client().usb_on_by_class(klass):
            return 0
        return 1

    def print_version(self):
        """Print the version of a local MultiTenantDeviceAccess instance."""
        version = MultiTenantDeviceAccess().version
        print("MTDA version: %s" % version)

    def main(self):
        config = None
        daemonize = False
        detach = True

        options, stuff = getopt.getopt(
            sys.argv[1:], 'c:dhnr:v',
            ['daemon', 'help', 'no-detach', 'remote=', 'version'])
        for opt, arg in options:
            if opt in ('-c', '--config'):
                config = arg
            if opt in ('-d', '--daemon'):
                daemonize = True
            if opt in ('-h', '--help'):
                self.help_cmd()
                sys.exit(0)
            if opt in ('-n', '--no-detach'):
                detach = False
            if opt in ('-r', '--remote'):
                self.remote = arg
            if opt in ('-v', '--version'):
                self.print_version()
                sys.exit(0)

        # Start our server
        if daemonize is True:
            self.agent = MultiTenantDeviceAccess()
            self.agent.load_config(None, daemonize, config)
            self.remote = self.agent.remote
            if detach is True:
                status = self.daemonize()
            else:
                status = self.server()
            if status is False:
                print('Failed to start the MTDA server!', file=sys.stderr)
                return False
            return True

        # Assume we want an interactive console if called without a command
        if len(stuff) == 0:
            stuff = ['console', 'interactive']

        # Check for non-option arguments
        cmd = stuff[0]
        stuff.pop(0)

        cmds = {
            'command': self.command_cmd,
            'console': self.console_cmd,
            'getenv': self.getenv_cmd,
            'help': self.help_cmd,
            'keyboard': self.keyboard_cmd,
            'monitor': self.monitor_cmd,
            'setenv': self.setenv_cmd,
            'storage': self.storage_cmd,
            'target': self.target_cmd,
            'usb': self.usb_cmd
        }

        if cmd in cmds:
            if cmd != 'help':
                # Start our agent
                self.agent = Client(self.remote, config_files=config)
                self.remote = self.agent.remote()
                self.agent.start()
            status = cmds[cmd](stuff)
            sys.exit(status)
        else:
            print("unknown command '%s'!" % (cmd), file=sys.stderr)
            self.help_cmd()
            sys.exit(1)
        return True


if __name__ == '__main__':
    app = Application()
    # main() exits by itself for client commands; in server mode it returns
    # False when the server could not be started, which must be reflected
    # in the exit status of the process.
    sys.exit(1 if app.main() is False else 0)
