#!/usr/bin/env python

import importlib
import argparse
import sys
nvidia_smi2 = importlib.import_module("nvidia-smi2")

# Command-line interface.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-l', '--command-length', default=20, const=100, type=int, nargs='?',
    help='width of the COMMAND column; longer commands are truncated '
         '(default: 20, or 100 when -l is given without a value)',
)
parser.add_argument(
    '-c', '--color', action='store_true',
    help='colorize the echoed nvidia-smi output',
)
parser.add_argument(
    '-u', '--user', type=str, default='',
    help='only show processes and totals for this user instead of all users',
)

args = parser.parse_args()

# command_length is spliced into a %-format spec ("%-N.Ns") further down; a
# negative value yields an invalid spec and an unhelpful ValueError there, so
# reject it here with a proper usage error instead.
if args.command_length < 0:
    parser.error('argument -l/--command-length: must be a non-negative integer')

command_length = args.command_length
color = args.color

# Raw nvidia-smi output, one entry per line.
lines = nvidia_smi2.get_nvidia_smi_stdout()

# get_line_to_print splits the output into: the header lines to echo back,
# the index of the first process-table line in `lines`, and a format flag
# that is later handed on to the process parser.
lines_to_print, ps_start_idx, is_new_format = nvidia_smi2.get_line_to_print(lines)

lines_to_print = nvidia_smi2.colorize(lines_to_print) if color else lines_to_print

# Echo the header, leaving out its final entry (the closing +---+ separator).
for header_line in lines_to_print[:-1]:
    print(header_line)

no_running_process = "No running processes found"
first_ps_line = lines[ps_start_idx]
# Issue #9: when running inside a docker container, no processes may be
# listed at all and the process section then begins with a "+--" border.
sees_nothing_in_container = first_ps_line.startswith("+--")

if no_running_process in first_ps_line or sees_nothing_in_container:
    border = lines[-1]
    # NOTE(review): the opening border is stripped but the closing one is
    # not; confirm whether lines from get_nvidia_smi_stdout can carry
    # surrounding whitespace.
    print(border.strip())
    print("| " + no_running_process.ljust(73) + "   |")
    if sees_nothing_in_container:
        print("| If you're running in a container, you'll only see processes running inside. |")
    print(border)
    sys.exit()

def _matches_user(user, wanted):
    """Return True if *user* passes the ``--user`` filter *wanted*.

    An empty *wanted* disables filtering. Otherwise only the first 7
    characters of both names are compared -- presumably because the user
    column can arrive truncated upstream; TODO confirm against
    nvidia_smi2.get_process_user_detail.
    """
    return not wanted or user[:7] == wanted[:7]


def _separator(row):
    """Return a "+---+" border exactly as wide as the formatted *row*."""
    return "+" + "-" * (len(row) - 2) + "+"


# Column keys of ps_detail, in the order they are printed.
_PROCESS_COLUMNS = ("gpu_num", "pid", "user", "gpu_mem", "cpu", "mem", "time", "command")

# Get all process detail
ps_detail, user_detail = nvidia_smi2.get_process_user_detail(ps_start_idx, is_new_format, lines)

# --- Per-process table ------------------------------------------------------
# The PID column is at least 5 wide and grows to fit the longest PID. Seeding
# the list with 5 also keeps max() from raising on an empty PID list.
max_pid_length = max([5] + [len(pid) for pid in ps_detail["pid"]])
# COMMAND is left-aligned and both padded and truncated to command_length.
print_format = ("|  %3s %" + str(max_pid_length) + "s %8s   %8s %5s %5s %9s  %-"
                + str(command_length) + "." + str(command_length) + "s  |")

line = print_format % (
    "GPU", "PID", "USER", "GPU MEM", "%CPU", "%MEM", "TIME", "COMMAND"
)
print(_separator(line))
print(line)

for row in zip(*[ps_detail[key] for key in _PROCESS_COLUMNS]):
    # row follows _PROCESS_COLUMNS, so row[2] is the user.
    if _matches_user(row[2], args.user):
        print(print_format % row)

print(_separator(line))

# --- Per-user totals --------------------------------------------------------
sum_format = "|  %8s   %14s %11s %11s  |"
line = sum_format % (
    "USER", "TOTAL GPU MEM", "TOTAL %CPU", "TOTAL %MEM"
)
print(line)
print(_separator(line))
for user, totals in user_detail.items():
    if not _matches_user(user, args.user):
        continue
    print(sum_format % (
        user,
        str(totals["total_gpu_mem"]) + "MiB",
        str(round(totals["total_cpu"], 1)),
        str(round(totals["total_mem"], 1)),
    ))

print(_separator(line))