#!/usr/bin/env python3

from collections import defaultdict
from enum import Enum
import os
import re
import shutil
import subprocess
import sys
from typing import Dict, Iterable, Optional

# TODO(raz): move some of the following functions to a common library


def which(cmd: str) -> str:
    """
    Find the original path of a genv shim.

    Searches for 'cmd' in the directories of the "PATH" environment variable,
    ignoring the directory of this file so that a shim never resolves to itself.
    Falls back to 'os.defpath' if "PATH" is not set.
    Raises 'RuntimeError' if not found.
    """
    shims = os.path.realpath(os.path.dirname(__file__))

    # "PATH" might be missing altogether (e.g. a stripped-down container
    # environment); fall back to the default search path instead of crashing.
    search_path = os.pathsep.join(
        directory
        for directory in os.environ.get("PATH", os.defpath).split(os.pathsep)
        if os.path.realpath(directory) != shims
    )

    result = shutil.which(cmd, path=search_path)

    if not result:
        raise RuntimeError(f"Could not find the path of '{cmd}'")

    return result


# TODO(raz): use this method from genv.utils.os_
def get_process_environ(pid: int) -> Dict[str, str]:
    """
    Returns the environment variables of the process with the given identifier.

    The variables are read from '/proc/<pid>/environ', where they are stored as
    NUL-separated 'NAME=value' entries.
    Raises 'FileNotFoundError' if no such process and 'PermissionError' if there is no sufficient permissions.
    """
    # Read raw bytes: environment values are not guaranteed to be valid text
    # and decoding the whole file at once would fail on a single bad byte.
    with open(f"/proc/{pid}/environ", "rb") as f:
        entries = f.read().split(b"\x00")

    environ = {}

    for entry in entries:
        variable, separator, value = entry.partition(b"=")

        if not separator:
            # empty chunk (e.g. after the last NUL) or an entry with no '='
            continue

        # 'os.fsdecode' never fails on undecodable bytes
        environ[os.fsdecode(variable)] = os.fsdecode(value)

    return environ


def find_process_eid(pid: int) -> Optional[str]:
    """
    Get the environment identifier of a process if it is running
    in an activated environment.

    The identifier is the value of the "GENV_ENVIRONMENT_ID" environment
    variable of the process.
    Returns 'None' if the variable is not set or if the environment of the
    process cannot be read.
    """
    try:
        environ = get_process_environ(pid)
    except (FileNotFoundError, PermissionError, ProcessLookupError):
        # The process does not exist, cannot be read with the current
        # permissions, or exited while its environment was being opened
        # (reported as ESRCH, i.e. 'ProcessLookupError').
        return None

    return environ.get("GENV_ENVIRONMENT_ID")


def memory_cap_to_mib(cap: str) -> int:
    """
    Convert the environment memory capacity string to an integer value in MiB.

    'cap' is an integer optionally followed by one of the units 'b', 'k', 'm',
    'g' (powers of 1000) or 'ki', 'mi', 'gi' (powers of 1024).
    A value with no unit is in bytes.
    The result is rounded down to a whole number of MiB.
    Raises 'ValueError' if the numeric part is not a valid integer.
    """

    def _bytes(cap: str) -> int:
        for unit, multiplier in [
            ("b", 1),
            ("k", 1000),
            ("m", 1000 * 1000),
            ("g", 1000 * 1000 * 1000),
            ("ki", 1024),
            ("mi", 1024 * 1024),
            ("gi", 1024 * 1024 * 1024),
        ]:
            if cap.endswith(unit):
                # Strip only the suffix; removing every occurrence of the unit
                # would silently accept malformed values such as "1m2m".
                return int(cap[: -len(unit)]) * multiplier

        return int(cap)  # the value is already in bytes if no unit was specified

    return _bytes(cap) // (1024 * 1024)


def print_warning_message(message: str) -> None:
    """
    Print a warning to the standard output, surrounded by blank lines.
    """
    print(f"\nWARNING: {message}\n")


def proxy(process: subprocess.Popen) -> None:
    """
    Reprint the output of a running nvidia-smi process while manipulating it.

    Lines are echoed untouched until the first device memory row is found.
    From that row on, lines are buffered until the output ends, and then
    printed so that:
      - rows of processes whose environment identifier differs from the one of
        this process ("GENV_ENVIRONMENT_ID") are dropped
      - the used memory of each device is replaced with the total memory of the
        remaining processes on that device
      - the total memory of each device is replaced with the environment memory
        capacity ("GENV_GPU_MEMORY") if set

    'process' must have been created with a text-mode standard output pipe.
    Returns after the entire output was consumed and the process has exited.
    """
    DEVICE = re.compile(r"(^\|.*)\| \s*(\d+)MiB / \s*(\d+)MiB \|(.*\|$)")
    DEVICE_INDEX = re.compile(r"^\|\s*(\d+)\s*.*\|$")
    PROCESS = [
        re.compile(r"^\|\s*(\d+)\s+(\d+).*?(\d+)MiB \|$"),  # < 11.1
        re.compile(
            r"^\|\s*(\d+)\s+N/A\s+N/A\s+(\d+).*?(\d+)MiB \|$"
        ),  # TODO(raz): support MIG (i.e. non N/A values)
    ]
    PROCESSES_HEADER_END = re.compile(r"^\|={77}\|$")
    NO_PROCESSES_FOUND = re.compile(r"^\|\s*No running processes found\s*\|$")

    EID = os.environ.get("GENV_ENVIRONMENT_ID")
    gpu_memory_cap = os.environ.get("GENV_GPU_MEMORY")

    buffering = False  # turned on by the first device memory row

    previous = None
    buffered = []  # pairs of a kept line and the line that preceded it
    env_processes = defaultdict(list)  # device index -> MiB of each env process
    total_processes = 0
    no_running_processes_found = False

    # Read until end-of-file rather than until the process exits: the process
    # may exit while some of its output is still waiting in the pipe, and that
    # output must not be lost.
    for line in process.stdout:
        if not buffering and DEVICE.match(line):
            buffering = True

        if not buffering:
            print(line, end="")
        else:
            skip = False

            if NO_PROCESSES_FOUND.match(line):
                no_running_processes_found = True
            else:
                for pattern in PROCESS:
                    match = pattern.match(line)
                    if match:
                        index = int(match.group(1))
                        pid = int(match.group(2))
                        mib = int(match.group(3))

                        total_processes += 1

                        if find_process_eid(pid) == EID:
                            env_processes[index].append(mib)
                        else:
                            skip = True

                        break

            if not skip:
                buffered.append((line, previous))

        previous = line

    # callers read the exit code right after this function returns
    process.wait()

    is_missing_process_information = (
        buffering  # this ensures we are in the default nvidia-smi view
        and total_processes == 0
        and not no_running_processes_found
    )

    # entries are only ever created when a process is added, so any entry at
    # this point means at least one process of this environment was found
    has_env_processes = any(env_processes.values())

    is_memory_usage_over_cap = False

    for line, previous in buffered:
        match = DEVICE.match(line)
        if match:
            total = int(match.group(3))

            if gpu_memory_cap:
                total = memory_cap_to_mib(gpu_memory_cap)

            if is_missing_process_information:
                used = "N/A "
            else:
                # the device index is taken from the row preceding the memory row
                index = int(DEVICE_INDEX.match(previous).group(1))
                used = sum(env_processes[index])

                if used > total:
                    is_memory_usage_over_cap = True

            line = DEVICE.sub(
                f"\\1| {str(used).rjust(6)}MiB / {str(total).rjust(5)}MiB |\\4", line
            )

        print(line, end="")

        # All the process rows were dropped: fill the otherwise empty table
        # right after its header separator.
        if (
            total_processes > 0
            and not has_env_processes
            and PROCESSES_HEADER_END.match(line)
        ):
            print("|  " + "No running processes found".ljust(75) + "|")

    if is_missing_process_information:
        print_warning_message(
            "Missing information about running processes. Cannot calculate GPU memory usage.\n\nIf you are running in a container, using the host pid namespace would solve this.\nThis could be done by passing '--pid host' to the 'docker run' command."
        )

    if is_memory_usage_over_cap:
        print_warning_message(
            f"The GPU memory usage for this environment has exceeded the set capacity of {gpu_memory_cap}"
        )


def run(args: Iterable[str], bypass: bool = False) -> int:
    """
    Run the original nvidia-smi and manipulate its output.

    'args' are the command line arguments to pass to nvidia-smi.
    If 'bypass' is set, the output is not captured and is left untouched.
    Returns the process exit code.
    """
    with subprocess.Popen(
        # unpack instead of concatenating so that any iterable is accepted
        [which("nvidia-smi"), *args],
        stdout=None if bypass else subprocess.PIPE,
        text=True,
    ) as process:
        if not bypass:
            proxy(process)

        # waiting guarantees that the exit code is available
        return process.wait()


# any non-zero integer in "GENV_BYPASS" turns off the output manipulation
bypass = int(os.environ.get("GENV_BYPASS", 0)) != 0

args = []

# NOTE(raz): the environment variable "CUDA_VISIBLE_DEVICES" will always be set for shell environments.
# however, it might not be set for other environment types like containers.
if not bypass and "CUDA_VISIBLE_DEVICES" in os.environ:
    # TODO(raz): handle modes of nvidia-smi where '--id' is not supported like 'pmon'
    args.append(f'--id={os.environ["CUDA_VISIBLE_DEVICES"]}')

# the arguments of the user come after the ones added by the shim
sys.exit(run(args + sys.argv[1:], bypass))
