#!/usr/bin/env python

from subprocess import run, PIPE
from argparse import ArgumentParser
from gpu_utils.utils import get_n_gpus, get_running_pids


def kill_interrupted_processes(sudo: bool = False) -> None:
    """
    Kill processes which `nvidia-smi` doesn't show but which are still using GPUs.

    Interrupting a process can sometimes lead to it not properly releasing GPU memory. This
    function tries to regain that memory by killing any processes that `nvidia-smi` doesn't
    show as running but that you can still see are using the GPUs using `lsof /dev/nvidia{gpu_id}`.
    `kill -9 {pid}` will be called for all such processes.

    :param sudo: if True, use sudo to find and kill interruped processes belong to all users
    """
    sudo = "sudo " if sudo else ""
    n_gpus = get_n_gpus()
    all_pids = set()
    for gpu_id in range(n_gpus):
        stdout = run(
            (sudo + f"lsof -t /dev/nvidia{gpu_id}").split(" "), stdout=PIPE
        ).stdout
        pids = stdout.decode().split()
        all_pids.update(pids)

    running_pids = get_running_pids()
    kill_pids = all_pids.difference(running_pids)
    run([*(sudo + "kill -9").split(" "), *kill_pids])


def main():
    parser = ArgumentParser(description="""Kill processes which `nvidia-smi` doesn't show but which are still using GPUs.

    Interrupting a process can sometimes lead to it not properly releasing GPU memory. This
    function tries to regain that memory by killing any processes that `nvidia-smi` doesn't
    show as running but that you can still see are using the GPUs using `lsof /dev/nvidia{gpu_id}`.
    `kill -9 {pid}` will be called for all such processes.""")
    parser.add_argument("-s", "--sudo", action="store_true", help="if True, use sudo to find and kill interruped processes belong to all users")

    args = parser.parse_args()
    kill_interrupted_processes(args.sudo)


if __name__ == "__main__":
    main()
