#!/usr/bin/env python3
# ---------------------------------------------------------------------------
# Command Line Interface for MTDA
# ---------------------------------------------------------------------------
#
# This software is a part of MTDA.
# Copyright (C) 2023 Siemens Digital Industries Software
#
# ---------------------------------------------------------------------------
# SPDX-License-Identifier: MIT
# ---------------------------------------------------------------------------

# System imports
import requests
import time
import sys
import zerorpc
import socket
from argparse import ArgumentParser, RawTextHelpFormatter

# Local imports
from mtda.main import MultiTenantDeviceAccess
from mtda.client import Client
from mtda.console.screen import ScreenOutput


class Application:

    def __init__(self):
        """Set up CLI state; the agent is only created later by ``main``."""
        # No agent/remote yet: both are filled in once arguments are parsed
        self.agent, self.remote = None, None
        # Flag polled by the console input loops to know when to stop
        self.exiting = False
        # Interactive sessions start on the console (not monitor) channel
        self.channel = "console"
        self.screen = ScreenOutput(self)

    def client(self):
        """Return the agent object used to reach MTDA (may be ``None``)."""
        agent = self.agent
        return agent

    def debug(self, level, msg):
        self.client().debug(level, msg)

    def command_cmd(self, args):
        result = self.client().command(args.cmd_command)
        if isinstance(result, str):
            print(result)
        else:
            print("Device command '{0}' failed!".format(
                " ".join(args)), file=sys.stderr)

    def console_clear(self, args):
        self.client().console_clear()

    def console_dump(self, args=None):
        data = self.client().console_dump()
        if data is not None:
            sys.stdout.write(data)
            sys.stdout.flush()

    def console_flush(self, args=None):
        data = self.client().console_flush()
        if data is not None:
            sys.stdout.write(data)
            sys.stdout.flush()

    def console_head(self, args):
        line = self.client().console_head()
        if line is not None:
            sys.stdout.write(line)
            sys.stdout.flush()

    def console_lines(self, args):
        lines = self.client().console_lines()
        sys.stdout.write("%d\n" % (lines))
        sys.stdout.flush()

    def console_interactive(self, args=None):
        """Run the interactive console until ``self.exiting`` becomes True.

        Attaches the local screen to the console and monitor of the
        remote, then forwards every key read locally to the channel that
        is currently selected. When stdin is a terminal, the target
        information is printed first and a prefix key is honoured: the key
        typed after the prefix is handed to ``console_menukey`` instead of
        being sent to the target.
        """
        # Both names refer to the same agent object (client() returns
        # self.agent)
        client = self.agent
        server = self.client()

        # Print target information
        if sys.stdin.isatty():
            self.target_info()

        # Connect to the consoles
        client.console_remote(self.remote, self.screen)
        client.monitor_remote(self.remote, self.screen)

        client.console_init()

        # Get prefix key (only when attached to a terminal; without one,
        # every key is forwarded as-is)
        prefix_key = None
        if sys.stdin.isatty():
            prefix_key = client.console_prefix_key()

        # Input loop
        while self.exiting is False:
            c = client.console_getkey()
            if prefix_key is not None and c == prefix_key:
                # Prefix seen: the next key selects a menu action
                c = client.console_getkey()
                self.console_menukey(c)
            elif self.channel == 'console':
                server.console_send(c, True)
            else:
                server.monitor_send(c, True)

        print("\r\nThank you for using MTDA!\r\n\r\n")

    def console_menukey(self, c):
        """Perform the menu action bound to key ``c``.

        Called by ``console_interactive`` with the key typed after the
        prefix key. Keys without a binding are silently ignored.

        Bindings: ``a`` lock target, ``b`` post console buffer to pastebin,
        ``c`` start/stop screen capture, ``i`` print target information,
        ``m`` switch between console and monitor, ``p`` toggle target
        power, ``q`` quit, ``r`` unlock target, ``s`` swap storage,
        ``t`` toggle timestamps, ``u`` toggle USB port 1.
        """
        # Both names refer to the same agent object (client() returns
        # self.agent)
        client = self.agent
        server = self.client()
        if c == 'a':
            status = server.target_lock()
            if status is True:
                server.console_print("\r\n*** Target was acquired ***\r\n")
        elif c == 'b':
            self.console_pastebin()
        elif c == 'c':
            if self.screen.capture_enabled() is False:
                self.screen.print(b"\r\n*** Screen capture started... ***\r\n")
                self.screen.capture_start()
            else:
                self.screen.capture_stop()
                self.screen.print(b"\r\n*** Screen capture stopped ***\r\n")
        elif c == 'i':
            self.target_info()
        elif c == 'm':
            if self.channel == 'console':
                # Switch the alternate screen buffer
                print("\x1b[?1049h")  # same as tput smcup
                self.channel = 'monitor'
            else:
                # Return to the main screen buffer
                print("\x1b[?1049l")  # same as tput rmcup
                self.channel = 'console'
            client.console_toggle()
        elif c == 'p':
            # Only announce the power state when it actually changed
            previous_status = server.target_status()
            server.target_toggle()
            new_status = server.target_status()
            if previous_status != new_status:
                server.console_print(
                    "\r\n*** Target is now %s ***\r\n" % (new_status))
        elif c == 'q':
            # Stop any capture and let the input loop terminate
            self.screen.capture_stop()
            self.exiting = True
        elif c == 'r':
            status = server.target_unlock()
            if status is True:
                server.console_print("\r\n*** Target was released ***\r\n")
        elif c == 's':
            # Only announce the storage location when it actually changed
            previous_status, writing, written = server.storage_status()
            server.storage_swap()
            new_status, writing, written = server.storage_status()
            if new_status != previous_status:
                server.console_print(
                    "\r\n*** Storage now connected to "
                    "%s ***\r\n" % (new_status))
        elif c == 't':
            server.toggle_timestamps()
        elif c == 'u':
            server.usb_toggle(1)

    def console_pastebin(self):
        client = self.agent
        server = self.client()
        api_key = client.pastebin_api_key()
        endpoint = client.pastebin_endpoint()
        if api_key is None or endpoint is None:
            server.console_print(
                    "\r\n*** key/endpoint for pastebin "
                    "are not configured! ***\r\n")
            return
        data = {
                'api_dev_key': api_key,
                'api_option': 'paste',
                'api_paste_code': self.client().console_dump(),
                'api_paste_format': 'python'
               }
        r = requests.post(url=endpoint, data=data)
        server = self.client()
        server.console_print(
            "\r\n*** console buffer posted to %s ***\r\n" % (r.text))

    def console_prompt(self, args):
        data = self.client().console_prompt(args.prompt)
        if data is not None:
            sys.stdout.write(data)
            sys.stdout.flush()

    def console_raw(self, args=None):
        """Run a raw console session until ``self.exiting`` becomes True.

        Unlike ``console_interactive``, no target information is printed,
        the monitor is not attached and no prefix key is interpreted:
        the current console buffer is dumped once and every key read is
        then sent to the console.
        """
        # Both names refer to the same agent object (client() returns
        # self.agent)
        client = self.agent
        server = self.client()

        # Connect to the console
        client.console_remote(self.remote, self.screen)

        # Input loop
        client.console_init()
        self.console_dump()
        while self.exiting is False:
            c = client.console_getkey()
            server.console_send(c, True)

    def console_run(self, args):
        data = self.client().console_run(args.console_command)
        if data is not None:
            sys.stdout.write(data)
            sys.stdout.flush()

    def console_send(self, args):
        self.client().console_send(args.send_string)

    def console_tail(self, args):
        line = self.client().console_tail()
        if line is not None:
            sys.stdout.write(line)
            sys.stdout.flush()

    def console_wait(self, args):
        result = self.client().console_wait(args.wait_string, args.timeout)
        return 0 if result is True else 1

    def console_cmd(self, args):
        cmds = {
           'clear': self.console_clear,
           'dump': self.console_dump,
           'flush': self.console_flush,
           'head': self.console_head,
           'interactive': self.console_interactive,
           'lines': self.console_lines,
           'prompt': self.console_prompt,
           'raw': self.console_raw,
           'run': self.console_run,
           'send': self.console_send,
           'tail': self.console_tail,
           'wait': self.console_wait
        }

        return cmds[args.subcommand](args)

    def getenv_cmd(self, args=None):
        value = self.agent.env_get(args.name)
        if value is not None:
            print(str(value))
        return 0

    def keyboard_write(self, args):
        self.client().keyboard_write(args.write_key)

    def keyboard_cmd(self, args):
        cmds = {
           'write': self.keyboard_write
        }

        cmds[args.subcommand](args)

    def monitor_cmd(self, args):
        cmds = {
           'send': self.monitor_send,
           'wait': self.monitor_wait
        }

        return cmds[args.subcommand](args)

    def monitor_send(self, args):
        self.client().monitor_send(args.send_string)

    def monitor_wait(self, args):
        result = self.client().monitor_wait(args.wait_string, args.timeout)
        return 0 if result is True else 1

    def setenv_cmd(self, args=None):
        self.agent.env_set(args.name, args.value)
        return 0

    def storage_cmd(self, args):
        cmds = {
           'host': self.storage_host,
           'mount': self.storage_mount,
           'target': self.storage_target,
           'update': self.storage_update,
           'write': self.storage_write
        }

        return cmds[args.subcommand](args)

    def storage_host(self, args=None):
        status = self.client().storage_to_host()
        if status is False:
            print("failed to connect the shared storage device to the host!",
                  file=sys.stderr)
            return 1
        return 0

    def storage_mount(self, args=None):
        status = self.storage_host()
        if status != 0:
            return 1
        status = self.agent.storage_mount(args.partnum)
        if status is False:
            print("'storage mount' failed!", file=sys.stderr)
            return 1
        return 0

    def storage_target(self, args):
        status = self.client().storage_to_target()
        if status is False:
            print("failed to connect the shared storage device to the target!",
                  file=sys.stderr)
            return 1
        return 0

    def _human_readable_size(self, size):
        """Format a byte count as whole KiB, whole MiB or GiB.

        Sizes below 1 MiB are shown in KiB and sizes below 1 GiB in MiB,
        both truncated to an integer; anything larger is shown in GiB
        with two decimals.
        """
        mib = 1024*1024
        gib = 1024*1024*1024
        if size >= gib:
            return "{:.2f} GiB".format(size/1024/1024/1024)
        if size >= mib:
            return "{:d} MiB".format(int(size/1024/1024))
        return "{:d} KiB".format(int(size/1024))

    def _storage_write_cb(self, imgname, totalread,
                          inputsize, totalwritten, imagesize):
        if imagesize:
            progress = \
                int((float(totalwritten) / float(imagesize)) * float(100))
        else:
            progress = int((float(totalread) / float(inputsize)) * float(100))
        blocks = int(round((20 * progress) / 100))
        spaces = ' ' * (20 - blocks)
        blocks = '#' * blocks
        elapsed = time.monotonic() - self.start_time
        speed = round((totalwritten / 1024 / 1024) / elapsed, 2)
        totalread = self._human_readable_size(totalread)
        totalwritten = self._human_readable_size(totalwritten)
        sys.stdout.write("\r{0}: [{1}] {2}% ({3} read, "
                         "{4} written, {5} MiB/s) ".format(
                             imgname, str(blocks + spaces), progress,
                             totalread, totalwritten, speed))
        sys.stdout.flush()

    def storage_update(self, args=None):
        self.start_time = time.monotonic()
        status = self.agent.storage_update(
            args.destination, args.source, self._storage_write_cb
        )
        sys.stdout.write("\n")
        sys.stdout.flush()

        if status is False:
            print("'storage update' failed!", file=sys.stderr)
            return 1
        return 0

    def storage_write(self, args=None):
        result = 0
        try:
            self.start_time = time.monotonic()
            self.agent.storage_write_image(
                args.image, self._storage_write_cb
            )
            sys.stdout.write("\n")
            sys.stdout.flush()
        except Exception as e:
            msg = e.msg if hasattr(e, 'msg') else str(e)
            print("\n'storage write' failed! ({})".format(msg),
                  file=sys.stderr)
            result = 1
        return result

    def target_uptime(self):
        result = ""
        uptime = self.client().target_uptime()
        days = int(uptime / (24 * 60 * 60.0))
        if days > 0:
            result = result + " %d days" % int(days)
            uptime = uptime % (24 * 60 * 60.0)
        hours = int(uptime / (60 * 60.0))
        if hours > 0:
            result = result + " %d hours" % int(hours)
            uptime = uptime % (60 * 60.0)
        minutes = int(uptime / 60.0)
        if minutes > 0:
            result = result + " %d minutes" % int(minutes)
            uptime = uptime % 60.0
        seconds = int(uptime)
        if seconds > 0:
            result = result + " %d seconds" % int(seconds)
        return result.strip()

    def target_info(self, args=None):
        """Print a summary of the host, the remote agent and the target.

        Shows host name and local version, remote name and agent version,
        the console prefix key, the session, target power state (with
        lock and uptime details), storage state, the status of every USB
        port and, when available, the video stream URL. ``args`` is
        unused.

        Raises:
            zerorpc.RemoteError: when querying the agent version fails
                for any reason other than a remote ``NameError``.
        """
        sys.stdout.write("\rFetching target information...\r")
        sys.stdout.flush()

        # Get general information
        client = self.client()
        locked = " (locked)" if client.target_locked() else ""
        remote = "Local" if self.remote is None else self.remote
        session = client.session()
        storage_status, writing, written = client.storage_status()
        writing = "WRITING" if writing is True else "IDLE"
        written = self._human_readable_size(written)
        tgt_status = client.target_status()
        uptime = ""
        if tgt_status == "ON":
            uptime = " (up %s)" % self.target_uptime()
        try:
            remote_version = client.agent_version()
        except zerorpc.RemoteError as e:
            # NOTE(review): a remote NameError presumably means the agent
            # predates agent_version(); it is reported as "<=0.5".
            if e.name == 'NameError':
                remote_version = "<=0.5"
            else:
                # Re-raise unchanged, keeping the original traceback
                raise

        host = MultiTenantDeviceAccess()
        # Map the control character to its letter (0x01 -> 'a', ...)
        prefix_key = chr(ord(client.console_prefix_key()) + ord('a') - 1)

        # Print general information (the trailing padding on the first
        # lines presumably blanks out the "Fetching..." text left on the
        # same row)
        print("Host           : %s (%s)%30s\r" % (
              socket.gethostname(), host.version, ""))
        print("Remote         : %s (%s)%30s\r" % (
              remote, remote_version, ""))
        print("Prefix key     : ctrl-%s\r" % (prefix_key))
        print("Session        : %s\r" % (session))
        print("Target         : %-6s%s%s\r" % (tgt_status, locked, uptime))
        print("Storage on     : %-6s%s\r" % (storage_status, locked))
        print("Storage writes : %s (%s)\r" % (written, writing))

        # Print status of the USB ports (ports are numbered from 1)
        ports = client.usb_ports()
        for ndx in range(0, ports):
            status = client.usb_status(ndx+1)
            print("USB #%-2d        : %s\r" % (ndx+1, status))

        # Print video stream details
        url = client.video_url()
        if url is not None:
            print("Video stream   : %s\r" % (url))

    def target_off(self, args=None):
        status = self.client().target_off()
        return 0 if (status is True) else 1

    def target_on(self, args=None):
        status = self.client().target_on()
        return 0 if (status is True) else 1

    def target_reset(self, args=None):
        status = self.client().target_status()
        if status != "OFF":
            status = self.target_off()
            if status != 0:
                return status
            time.sleep(5)
        status = self.target_on()
        return status

    def target_toggle(self, args=None):
        previous_status = self.client().target_status()
        self.client().target_toggle()
        new_status = self.client().target_status()
        return 0 if (new_status != previous_status) else 1

    def target_uptime_cmd(self, args=None):
        uptime = self.target_uptime()
        print(uptime)
        return 0

    def target_cmd(self, args):
        cmds = {
           'off': self.target_off,
           'on': self.target_on,
           'reset': self.target_reset,
           'toggle': self.target_toggle,
           'uptime': self.target_uptime_cmd
        }

        return cmds[args.subcommand](args)

    def usb_cmd(self, args):
        cmds = {
           'off': self.usb_off,
           'on': self.usb_on
        }

        return cmds[args.subcommand](args)

    def usb_off(self, args):
        klass = args.option
        client = self.client()
        result = client.usb_off_by_class(klass)
        return 0 if result else 1

    def usb_on(self, args):
        klass = args.option
        client = self.client()
        result = client.usb_on_by_class(klass)
        return 0 if result else 1

    def print_version(self):
        """Print the version of the locally installed MTDA package."""
        local = MultiTenantDeviceAccess()
        version = local.version
        print("MTDA version: %s" % version)

    def main(self):
        """Parse the command line, connect to the agent and run a command.

        Builds the argument parser (one sub-parser per command, each bound
        to its handler through ``func``), creates and starts the
        ``Client`` for the selected remote, calls the handler and exits
        the process with the handler's return value.

        ``--version`` prints the local version and exits before any
        client is created. When no command is given, ``console
        interactive`` is assumed.
        """
        parser = ArgumentParser(
            allow_abbrev=False,
            formatter_class=RawTextHelpFormatter,
        )
        parser.add_argument("-v", "--version", action="store_true")
        parser.add_argument(
            "-c", "--config", metavar="", required=False, help="config file"
        )
        parser.add_argument(
            "-r",
            "--remote",
            metavar="",
            default="localhost",
            required=False,
            help="remote",
        )

        # subcommand: command
        subparsers = parser.add_subparsers(dest="command")
        # The command itself is optional (see the interactive fallback
        # below)
        subparsers.required = False
        cmd = self.command_cmd
        p = subparsers.add_parser(
            "command",
            help="Send a command (string) to the device",
        )
        p.add_argument(
            "cmd_command", metavar="cmd", type=str, help="Command to send"
        )
        p.set_defaults(func=cmd)

        # subcommand: console
        cmd = self.console_cmd
        p = subparsers.add_parser(
            "console",
            help="Interact with the device console",
        )
        p.set_defaults(func=cmd)
        subsub = p.add_subparsers(dest="subcommand")
        subsub.required = True
        s = subsub.add_parser(
            "clear",
            help="Clear any data present in the console buffer",
        )
        s = subsub.add_parser(
            "dump",
            help="Dump content of the console buffer",
        )
        s = subsub.add_parser(
            "flush",
            help="Flush content of the console buffer",
        )
        s = subsub.add_parser(
            "head",
            help="Fetch and print the first line from the console buffer",
        )
        s = subsub.add_parser(
            "interactive",
            help="Open the device console for interactive use",
        )
        s = subsub.add_parser(
            "lines",
            help="Print number of lines present in the console buffer",
        )
        s = subsub.add_parser(
            "prompt",
            help="Configure or print the target shell prompt",
        )
        s.add_argument(
            "prompt_str",
            metavar="string",
            nargs="?",
            type=str,
            help="Console prompt string to set",
        )
        s = subsub.add_parser(
            "raw",
            help="Open the console for use from scripts",
        )
        s = subsub.add_parser(
            "run",
            help="Run the specified command via the device console",
        )
        s.add_argument(
            "console_command",
            metavar="command",
            type=str,
            help="Command to run"
        )
        s = subsub.add_parser(
            "send",
            help="Send characters to the device console",
        )
        s.add_argument("send_string", type=str, help="String to send")
        s = subsub.add_parser(
            "tail",
            help="Fetch and print the last line from the console buffer",
        )
        s = subsub.add_parser(
            "wait",
            help="Wait for the specified string on the console",
        )
        s.add_argument("wait_string", type=str, help="String to wait")
        s.add_argument("timeout", type=int, nargs="?", help="Timeout")

        # subcommand: getenv
        cmd = self.getenv_cmd
        p = subparsers.add_parser(
            "getenv",
            help="Get named variable from the environment",
        )
        p.add_argument("name", type=str, help="Key")
        p.set_defaults(func=cmd)

        # subcommand: keyboard
        cmd = self.keyboard_cmd
        p = subparsers.add_parser(
            "keyboard",
            help="Write a string of characters via the keyboard",
        )
        p.set_defaults(func=cmd)
        subsub = p.add_subparsers(dest="subcommand")
        subsub.required = True
        s = subsub.add_parser(
            "write", help="Write a string of characters via the keyboard"
        )
        s.add_argument(
            "write_key",
            metavar="key",
            type=str,
            help="Keyboard Key"
        )

        # subcommand: monitor
        cmd = self.monitor_cmd
        p = subparsers.add_parser(
            "monitor",
            help="Interact with the device monitor console (if any)",
        )
        p.set_defaults(func=cmd)
        subsub = p.add_subparsers(dest="subcommand")
        subsub.required = True
        s = subsub.add_parser(
            "send",
            help="Send characters to the device monitor"
        )
        s.add_argument("send_string", type=str, help="String to send")
        s = subsub.add_parser(
            "wait", help="Wait for the specified string on the monitor"
        )
        s.add_argument("wait_string", type=str, help="String to wait")
        s.add_argument("timeout", type=int, nargs="?", help="Timeout")

        # subcommand: setenv
        cmd = self.setenv_cmd
        p = subparsers.add_parser(
            "setenv",
            help="Set named variable to specified value in the environment",
        )
        p.add_argument("name", type=str, help="Key")
        p.add_argument("value", type=str, nargs="?", help="Value")
        p.set_defaults(func=cmd)

        # subcommand: storage
        cmd = self.storage_cmd
        p = subparsers.add_parser(
            "storage",
            help="Interact with the shared storage device",
        )
        p.set_defaults(func=cmd)
        subsub = p.add_subparsers(dest="subcommand")
        subsub.required = True
        s = subsub.add_parser(
            "host", help="Attach the shared storage device to the host"
        )
        s = subsub.add_parser(
            "mount", help="Mount the shared storage device on the host"
        )
        s.add_argument(
            "partnum",
            metavar="partition number",
            type=int,
            nargs="?",
            help="Partition number to mount",
        )
        s = subsub.add_parser(
            "target", help="Attach the shared storage device to the target"
        )
        s = subsub.add_parser(
            "update",
            help="Update the specified file on the shared storage device"
        )
        s.add_argument(
            "destination",
            type=str,
            help="Location in mounted partition"
        )
        s.add_argument(
            "source",
            metavar="filename",
            type=str, nargs="?",
            help="Path to file"
        )
        s = subsub.add_parser(
            "write",
            help="Write an image to the shared storage device"
        )
        s.add_argument(
            "image",
            metavar="image",
            type=str,
            help="Path to image file"
        )

        # subcommand: target
        cmd = self.target_cmd
        p = subparsers.add_parser(
            "target",
            help="Power control the device",
        )
        p.set_defaults(func=cmd)
        subsub = p.add_subparsers(dest="subcommand")
        subsub.required = True
        s = subsub.add_parser("on", help="Power on the device")
        s = subsub.add_parser("off", help="Power off the device")
        s = subsub.add_parser("reset", help="Power cycle the device")
        s = subsub.add_parser("toggle", help="Toggle target power")
        s = subsub.add_parser("uptime", help="Print target uptime")

        # subcommand: usb
        cmd = self.usb_cmd
        p = subparsers.add_parser(
            "usb",
            help="Control USB devices attached to the device",
        )
        p.set_defaults(func=cmd)
        subsub = p.add_subparsers(dest="subcommand")
        # Like every other command group, a sub-command is mandatory:
        # usb_cmd() cannot dispatch a missing one
        subsub.required = True
        # The class is declared on the parent parser, after the
        # sub-command, so the syntax is "usb {on,off} <class>"
        p.add_argument(
            "option",
            metavar="class",
            type=str,
            help="usb class type"
        )
        s = subsub.add_parser("on", help="Power on the specified USB device")
        s = subsub.add_parser(
            "off", help="Power off the specified USB device"
        )

        args = parser.parse_args()
        if args.version:
            self.print_version()
            sys.exit(0)
        config = args.config
        self.remote = args.remote
        self.agent = Client(self.remote, config_files=config)
        # Use the remote as reported back by the client from here on
        self.remote = self.agent.remote()
        self.agent.start()
        # Assume we want an interactive console if called without a command
        if args.command is None:
            args = parser.parse_args(['console', 'interactive'])
        status = args.func(args)
        sys.exit(status)


# Script entry point: build the CLI application and hand control to it
if __name__ == '__main__':
    Application().main()
