#!/usr/bin/env python3

import os
import shutil
import sys
from typing import Iterable, NoReturn, Optional

# TODO(raz): move some of the following functions to a common library


def which(cmd: str) -> Optional[str]:
    """
    Find the original path of a genv shim.

    Looks 'cmd' up in PATH while skipping the directory that contains this
    file (the shims directory), so that a shim never resolves to itself.
    When PATH is not set, the platform default search path (os.defpath) is
    used instead of failing.

    :param cmd: Name of the command to look up
    :return: Path of the first match outside the shims directory, or None if
        there is no such match
    """
    shims = os.path.realpath(os.path.dirname(__file__))

    # PATH might be missing from the environment (e.g. 'env -i'); fall back to
    # the default search path rather than calling '.split' on None.
    search_path = os.environ.get("PATH", os.defpath)

    # compare resolved paths so that symlinks to the shims directory are
    # skipped as well
    path = os.pathsep.join(
        directory
        for directory in search_path.split(os.pathsep)
        if os.path.realpath(directory) != shims
    )

    return shutil.which(cmd, path=path)


def attached_devices() -> Iterable[int]:
    """
    Return the device indices listed in CUDA_VISIBLE_DEVICES.

    The variable is parsed as a comma-separated list of integers. Entries
    equal to -1 are dropped (presumably the placeholder for "no device
    attached"; confirm against the code that sets the variable). Blank
    entries are ignored, and an unset or empty variable is treated as no
    attached devices.

    :return: List of device indices, in the order they appear
    :raises ValueError: If an entry is not an integer
    """
    raw = os.environ.get("CUDA_VISIBLE_DEVICES", "")

    # skip blank tokens so that "" and stray commas do not reach int()
    indices = (int(token) for token in raw.split(",") if token.strip())

    return [index for index in indices if index != -1]


def find_gpus_arg(docker_args: Iterable[str]) -> Optional[int]:
    """
    Locate the value of the first '--gpus' option in the given arguments.

    Both forms are recognized:
      - '--gpus VALUE': the index of the following argument is returned
      - '--gpus=VALUE': the index of the argument itself is returned

    Only the first occurrence is considered.

    :param docker_args: Command line arguments to scan
    :return: Index of the argument holding the value, or None if there is no
        '--gpus' option or if a bare '--gpus' is the last argument
    """
    # materialize so that any iterable works (the annotation promises
    # Iterable, but a length is needed to detect a trailing '--gpus')
    args = list(docker_args)
    last = len(args) - 1

    for position, arg in enumerate(args):
        if arg == "--gpus":
            # the value is the next argument, if there is one
            return position + 1 if position < last else None

        if arg.startswith("--gpus="):
            return position

    return None


def get_gpus_arg(docker_args: Iterable[str], index: int) -> str:
    """
    Return the value of the '--gpus' option stored at the given index.

    The argument at 'index' is either the bare value (when the option was
    passed as '--gpus VALUE') or '--gpus=VALUE'. Only a leading '--gpus=' is
    stripped; any later occurrence of that text is part of the value.

    :param docker_args: Command line arguments; must support indexing
    :param index: Index of the argument holding the value
    :return: The value without the '--gpus=' prefix
    """
    prefix = "--gpus="
    arg = docker_args[index]

    return arg[len(prefix) :] if arg.startswith(prefix) else arg


def set_gpus_arg(docker_args: Iterable[str], index: int, value: str) -> None:
    """
    Overwrite, in place, the '--gpus' value stored at the given index.

    The form of the existing argument is kept: an argument that starts with
    '--gpus=' becomes '--gpus=VALUE', anything else is replaced by the bare
    value.
    """
    if docker_args[index].startswith("--gpus="):
        replacement = f"--gpus={value}"
    else:
        replacement = value

    docker_args[index] = replacement


def exec(docker_args: Iterable[str]) -> NoReturn:
    """
    Execute the original 'docker' command with the specified arguments.
    This calls Linux exec syscall and does not return.

    If the original 'docker' executable cannot be found, an error is printed
    to stderr and the process exits with status 127.

    :param docker_args: Arguments to pass to 'docker' (without the program
        name itself)
    """
    path = which("docker")

    if path is None:
        # without this check, None would reach os.execv and fail with an
        # opaque TypeError; 127 is the shell's "command not found" status
        print("Could not find the original 'docker' executable", file=sys.stderr)
        sys.exit(127)

    # argv[0] is the program path by convention; unpacking accepts any
    # iterable of arguments, not just a list
    os.execv(path, [path, *docker_args])


if __name__ == "__main__":
    # Shim entry point: rewrite a '--gpus' request so that it refers only to
    # the devices attached to the current environment, then hand over to the
    # original 'docker' executable.
    docker_args = sys.argv[1:]

    index = find_gpus_arg(docker_args)

    # a non-zero integer in GENV_BYPASS forwards the command line untouched
    bypass = int(os.environ.get("GENV_BYPASS", 0)) != 0

    if index is not None and not bypass:
        value = get_gpus_arg(docker_args, index)

        indices = attached_devices()

        # isdecimal() accepts only digits that int() can parse; isnumeric()
        # also accepts characters such as superscripts and fractions, on
        # which int() raises ValueError
        if value.isdecimal():
            count = int(value)

            if count > len(indices):
                print(
                    f"Requested {count} devices but only {len(indices)} are attached to environment",
                    file=sys.stderr,
                )
                sys.exit(1)

            # a device count is served by the first attached devices
            indices = indices[:count]
        elif value != "all":
            print(f"Unsupported value for '--gpus' ({value})", file=sys.stderr)
            sys.exit(1)

        # the device list is wrapped in literal double quotes, which are part
        # of the value handed to docker; with no devices, request a count of 0
        devices = ",".join(str(device) for device in indices)
        set_gpus_arg(docker_args, index, f'"device={devices}"' if indices else "0")

        # forward the environment identifier to the container, right after the
        # '--gpus' value; skipped when unset so that the literal string "None"
        # is never passed as an identifier
        environment_id = os.environ.get("GENV_ENVIRONMENT_ID")

        if environment_id is not None:
            docker_args[index + 1 : index + 1] = [
                "-e",
                f"GENV_ENVIRONMENT_ID={environment_id}",
            ]

    exec(docker_args)
