#!/usr/bin/python
# SPDX-License-Identifier: MIT

import argparse
import base64
import json
import socket
import sys
import time


class QemuException(Exception):
    """Error raised when a guest agent reply does not carry a result.

    Derives from ``Exception`` (not ``BaseException``) so that ordinary
    ``except Exception`` handlers can see it; ``BaseException`` is reserved
    for interpreter-level events such as ``SystemExit``.
    """

    def __init__(self, message):
        """Store *message* (the agent's error payload) as the exception argument."""
        super().__init__(message)


# Verbosity thresholds passed as the ``level`` argument of verbose(): a
# message is printed when the number of -v flags is at least its level.
# VERBOSE_USER: progress messages (connection target, command, environment).
VERBOSE_USER = 1
# VERBOSE_DEV: raw request/response bytes exchanged with the guest agent.
VERBOSE_DEV = 2


def verbose(options, level, *args, **kwargs):
    """Print a diagnostic message to stderr if verbosity is high enough.

    Args:
        options: object with a ``verbose`` attribute holding the requested
            verbosity (an int, or ``None`` when no ``-v`` flag was given).
        level: minimum verbosity at which the message is shown.
        *args: positional arguments forwarded to ``print``.
        **kwargs: keyword arguments forwarded to ``print`` (e.g. ``sep``,
            ``end``); ``file`` is always ``sys.stderr``.
    """
    # argparse's "count" action leaves the attribute as None when the flag
    # never appears; treat that as verbosity 0 rather than comparing None
    # with an int, which raises TypeError on Python 3.
    current = options.verbose or 0
    if current >= level:
        print(*args, file=sys.stderr, **kwargs)


def parse_options(argv=None):
    """Parse the command line and build the guest environment list.

    Args:
        argv: optional list of argument strings; when ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        The argparse namespace. ``env`` holds the user-supplied ``NAME=VALUE``
        entries, followed by the default Windows variables (unless
        ``--no-default-env``), followed by a single ``PATH=...`` entry built
        from the default system folders (unless ``--no-default-path``) and
        the ``--path`` components, joined with ``;``.
    """
    parser = argparse.ArgumentParser(
        description="Command executor for QEMU Windows guests."
    )
    parser.add_argument(
        "--connect", "-c", required=True, help="virtio-serial Unix socket path"
    )
    # default=0: without it the "count" action yields None when -v is absent,
    # and comparing None against an integer level fails on Python 3.
    parser.add_argument(
        "--verbose",
        "-v",
        action="count",
        default=0,
        help="increase output verbosity",
    )
    parser.add_argument(
        "--env",
        "-e",
        action="append",
        metavar="NAME=VALUE",
        default=[],
        help="guest environment variable",
    )
    parser.add_argument(
        "--path", "-p", action="append", default=[], help="guest PATH component"
    )
    parser.add_argument(
        "--no-default-path",
        "-P",
        action="store_true",
        help="do not prepend PATH with system folders",
    )
    parser.add_argument(
        "--no-default-env",
        "-E",
        action="store_true",
        help="do not prepend environment with default values",
    )
    parser.add_argument("command", metavar="COMMAND", nargs="+", help="guest command")
    options = parser.parse_args(argv)

    # System folders go in front of any user-supplied PATH components.
    path = options.path
    if not options.no_default_path:
        default_path = [
            r"C:\Windows\system32",
            r"C:\Windows",
        ]
        path = default_path + path

    # NOTE(review): the help text says defaults are "prepended", but they are
    # placed after the user's entries here; the order is kept as it was.
    if not options.no_default_env:
        options.env = options.env + [
            r"SystemDrive=C:",
            r"SystemRoot=C:\Windows",
            r"ProgramData=C:\ProgramData",
            r"ProgramFiles=C:\Program Files",
            r"ProgramFiles(x86)=C:\Program Files (x86)",
        ]
    options.env.append(f"PATH={';'.join(path)}")

    return options


def run(channel, options):
    """Run the guest command through the guest agent and collect its result.

    Sends a ``guest-exec`` request with output capture enabled, then polls
    ``guest-exec-status`` once per second until the process is reported as
    exited.

    Args:
        channel: connected socket-like object providing ``sendall`` and
            ``recv``.
        options: parsed options; ``command``, ``env`` and ``verbose`` are
            used.

    Returns:
        A ``(code, stdout, stderr)`` tuple: the exit code and the captured
        output streams as bytes (empty when the agent reported none).

    Raises:
        QemuException: a reply did not contain a ``return`` value.
        ConnectionError: the channel was closed before a full reply arrived.
    """

    def send(command, args=None):
        # The request is a single JSON object terminated by a newline.
        request = {"execute": command, "arguments": args if args else {}}
        line = f"{json.dumps(request)}\n".encode("utf-8")
        verbose(options, VERBOSE_DEV, "Sending: ", line)
        channel.sendall(line)

    def parse(line):
        response = json.loads(line[:-1])
        result = response.get("return")
        if result is None:
            # Fall back to the whole reply so that a malformed response is
            # reported as a QemuException rather than a bare KeyError.
            raise QemuException(response.get("error", response))
        return result

    def receive():
        # Accumulate chunks until the newline that terminates the reply.
        buffer = bytearray()
        while True:
            part = channel.recv(2048)
            verbose(options, VERBOSE_DEV, "Received: ", part)
            if not part:
                # recv() returns b"" once the peer has closed the socket;
                # without this check the loop would spin forever.
                raise ConnectionError(
                    "guest agent channel closed before a complete response"
                )
            buffer += part
            if buffer.endswith(b"\n"):
                return parse(bytes(buffer))

    def do(command, args=None):
        send(command, args)
        return receive()

    def decode(status, key):
        # Captured output is base64-encoded; the key is absent when empty.
        value = status.get(key)
        if value is not None:
            return base64.b64decode(value)
        return b""

    process = do(
        "guest-exec",
        {
            "path": options.command[0],
            "arg": options.command[1:],
            "env": options.env,
            "capture-output": True,
        },
    )

    while True:
        status = do("guest-exec-status", {"pid": process["pid"]})
        if status["exited"]:
            break
        time.sleep(1)

    # "exitcode" is optional in the guest agent reply: an abnormally
    # terminated process is reported through "signal" instead. Report a
    # generic failure in that case instead of dying on a KeyError and
    # losing the captured output.
    code = status.get("exitcode")
    if code is None:
        verbose(
            options,
            VERBOSE_DEV,
            "Guest process terminated abnormally, signal:",
            status.get("signal"),
        )
        code = 1
    stdout = decode(status, "out-data")
    stderr = decode(status, "err-data")
    return (code, stdout, stderr)


def main():
    """Script entry point.

    Parses the command line, optionally prints the connection target, guest
    command and guest environment, connects to the Unix socket, runs the
    guest command, copies its captured output to this process's stdout and
    stderr, and exits with the code returned by run().
    """
    options = parse_options()

    def report(*args, **kwargs):
        # Every message emitted from here is a user-level diagnostic.
        verbose(options, VERBOSE_USER, *args, **kwargs)

    if options.verbose:
        report("Connect to:", options.connect)
        quoted_words = [repr(word) for word in options.command]
        report("Guest command:", " ".join(quoted_words))
        report("Guest environment:")
        for entry in options.env:
            report("\t", entry, sep="")

    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM, 0) as channel:
        report("Connecting... ", end="")
        channel.connect(options.connect)
        report("OK.")

        outcome = run(channel, options)
        exit_code, guest_stdout, guest_stderr = outcome
        # Write raw bytes so guest output is passed through undecoded.
        sys.stdout.buffer.write(guest_stdout)
        sys.stderr.buffer.write(guest_stderr)
        sys.exit(exit_code)


# Run main() only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
