#!python

import py3nvml.nvidia_smi as smi
import xmltodict
from datetime import datetime
import re
import os
import pwd
from subprocess import Popen, PIPE
import argparse
from time import sleep
import sys

# Command-line interface. A positive --loop value is the pause handed to
# sleep() between refreshes; the default of 0 prints the report once.
parser = argparse.ArgumentParser(description='Print GPU stats')
parser.add_argument('-l', '--loop', action='store', type=int,
                    default=0, help='Loop period')

# Number of border characters between the '+' corners of each GPU table column.
COL1_WIDTH = 33
COL2_WIDTH = 21
COL3_WIDTH = 21
# Inner width of a full-width border: the three columns plus the two inner '+'.
WIDTH = 77
# When True, the header and every GPU entry get the extra name/bus-id line(s).
LONG_FORMAT = False

# Cell templates for the three columns of a GPU row.
gpu_format_col1 = '| {:>3} {:3} {:>5} {:>4} {:>11}|'
gpu_format_col2 = ' {:>19} |'
gpu_format_col3 = ' {:>8} {:>10} |'
# Process row: GPU, owner, PID, uptime, process name, memory amount, unit.
# Every field has a fixed width so a row is always WIDTH + 2 characters; the
# unit field is '{:<3}' (left-aligned, width 3) -- '{:3<}' would mean
# "fill with '3', no width" and would never pad a short unit.
proc_format = '| {:>3}  {:>11}  {:>5}  {:>11}  {:<27}  {:>5}{:<3} |'


def print_header(driver_version):
    """Print the timestamp, banner and GPU table column titles.

    Args:
        driver_version: value shown after 'Driver Version:' in the banner.

    Returns:
        The number of lines written to stdout.
    """
    def column_rule(fill):
        # '+' at both ends and between the three column-wide runs of `fill`
        return '+'.join(
            ['', fill * COL1_WIDTH, fill * COL2_WIDTH, fill * COL3_WIDTH, ''])

    print(datetime.now().strftime('%a %b %d %H:%M:%S %Y'))
    print('+' + '-' * WIDTH + '+')
    banner_left = '| ' + '{:<34}'.format('NVIDIA-SMI')
    banner_right = 'Driver Version: {:<26}'.format(driver_version)
    print(banner_left + banner_right + '|')
    print(column_rule('-'))
    printed = 4

    if not LONG_FORMAT:
        # Compact layout: one title row built from the GPU row templates
        titles = gpu_format_col1.format(
            'GPU', 'Fan', 'Temp', 'Perf', 'Pwr:Usage/Cap')
        titles += gpu_format_col2.format('Memory-Usage')
        titles += gpu_format_col3.format('GPU-Util', 'Compute M.')
        print(titles)
        printed += 1
    else:
        # Long layout: two fixed title rows
        print('| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |')
        print('| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |')
        printed += 2

    print(column_rule('='))
    return printed + 1


def print_proc_header():
    """Close the GPU table and print the heading of the process table.

    Returns:
        The number of lines written to stdout (always 6).
    """
    rows = (
        # bottom border of the GPU table
        '+' + '-' * COL1_WIDTH + '+' + '-' * COL2_WIDTH + '+' + '-' * COL3_WIDTH + '+',
        # spacer line made of blanks between the two tables
        ' ' * 79,
        '+' + '-' * WIDTH + '+',
        '| Processes:                                                       GPU Memory |',
        proc_format.format('GPU', 'Owner', 'PID', 'Uptime', 'Process Name', 'Usage', '   '),
        '+' + '=' * WIDTH + '+',
    )
    for row in rows:
        print(row)
    return len(rows)

def enabled_str(x):
    """Return 'On' when x equals 'Enabled', and 'Off' for anything else."""
    return 'On' if x == 'Enabled' else 'Off'


# Matches an optionally negative decimal number at the start of a string
_LEADING_NUM_RE = re.compile(r'^-?[\d.]+')


def get_num_from_str(s):
    """Return the number at the start of s, rounded to the nearest integer.

    Trailing text such as a unit ('35 %', '12.5 W') is ignored. Halves are
    rounded up, for negative numbers as well.

    Args:
        s: string beginning with an optionally negative decimal number.

    Returns:
        The leading number as an int.

    Raises:
        ValueError: if s does not start with a number (e.g. 'N/A').
    """
    found = _LEADING_NUM_RE.match(s)
    if found is None:
        raise ValueError('no leading number in {!r}'.format(s))
    value = float(found.group())
    # Floor (not int() truncation) after adding 0.5, so that negative values
    # are rounded to the nearest integer too instead of toward zero.
    return int((value + 0.5) // 1)


def print_gpu_info(gpu_info):
    """Print the table row(s) for one GPU.

    Args:
        gpu_info: mapping describing a single GPU; the keys read here are
            'minor_number', 'fan_speed', 'temperature', 'performance_state',
            'power_readings', 'fb_memory_usage', 'utilization' and
            'compute_mode', plus 'product_name', 'persistence_mode', 'pci'
            and 'display_active' when LONG_FORMAT is set.

    Returns:
        The number of lines written to stdout.
    """
    rows = 0
    if LONG_FORMAT:
        # Extra first line: index, name, persistence | bus id, display | ECC
        name_cell = '| {:>3}  {:<19} {:<4} |'.format(
            gpu_info['minor_number'],
            gpu_info['product_name'],
            enabled_str(gpu_info['persistence_mode']))
        bus_cell = ' {:0>16} {:<3} |'.format(
            # keeps the text between the first pair of single quotes
            gpu_info['pci']['pci_bus_id'].split("'")[1],
            enabled_str(gpu_info['display_active']))
        ecc_cell = ' {:>20} |'.format('N/A')
        print(name_cell + bus_cell + ecc_cell)
        rows += 1

    # Column 1: index, fan, temperature, performance state, power draw / limit
    index = gpu_info['minor_number']
    fan = '{}%'.format(get_num_from_str(gpu_info['fan_speed']))
    temp = '{}C'.format(get_num_from_str(gpu_info['temperature']['gpu_temp']))
    perf = gpu_info['performance_state']
    power = '{:>4}W /{:>4}W '.format(
        get_num_from_str(gpu_info['power_readings']['power_draw']),
        get_num_from_str(gpu_info['power_readings']['power_limit']))
    print(gpu_format_col1.format(index, fan, temp, perf, power), end='')

    # Column 2: framebuffer memory used / total
    memory = '{:>5}MiB / {:>5}MiB'.format(
        get_num_from_str(gpu_info['fb_memory_usage']['used']),
        get_num_from_str(gpu_info['fb_memory_usage']['total']))
    print(gpu_format_col2.format(memory), end='')

    # Column 3: GPU utilisation and compute mode; this print ends the row
    util = '{}%'.format(get_num_from_str(gpu_info['utilization']['gpu_util']))
    print(gpu_format_col3.format(util, gpu_info['compute_mode']))
    rows += 1

    return rows


def cut_proc_name(name, maxlen):
    """Shorten name to at most maxlen characters, keeping its tail.

    A name that already fits is returned unchanged. Otherwise the result is
    '...' followed by the last maxlen - 3 characters of name. When maxlen is
    too small to hold any of the name, only (a prefix of) the ellipsis is
    returned so the result never exceeds maxlen.

    Args:
        name: the string to shorten.
        maxlen: maximum length of the returned string.

    Returns:
        A string no longer than max(maxlen, 0).
    """
    if len(name) <= maxlen:
        return name
    ellipsis = '...'
    keep = maxlen - len(ellipsis)
    if keep <= 0:
        # No room for name characters; a slice like name[-0:] would return
        # the whole name, so handle this case explicitly.
        return ellipsis[:max(maxlen, 0)]
    return ellipsis + name[-keep:]


def get_uname_pid(pid):
    """Return the login name of the user owning process pid.

    The owner is taken from the uid of the /proc/<pid> directory and looked
    up in the password database.

    Args:
        pid: process id as an integer.

    Returns:
        The user name, or '???' if it cannot be determined.
    """
    try:
        # the /proc/PID directory is owned by the process creator
        uid = os.stat("/proc/%d" % pid).st_uid
        # look up the username from the uid
        username = pwd.getpwuid(uid)[0]
    except Exception:
        # Best effort: any ordinary failure (process gone, no /proc, unknown
        # uid, bad pid) yields the placeholder. KeyboardInterrupt and
        # SystemExit are deliberately not caught.
        username = '???'
    return username


def get_uptime(pid):
    """Return the elapsed running time of process pid as reported by ps.

    Runs `ps -q <pid> -o etime=` and returns its stripped standard output,
    which is empty when ps prints nothing for that pid.

    Args:
        pid: process id (anything str() turns into a pid).

    Returns:
        The elapsed-time string from ps, or '?' if ps could not be run or
        its output could not be decoded.
    """
    try:
        sess = Popen(['ps', '-q', str(pid), '-o', 'etime='], stdout=PIPE, stderr=PIPE)
        # stderr is captured only to keep it off the terminal
        stdout, _ = sess.communicate()
        uptime = stdout.decode('utf-8').strip()
    except Exception:
        # Best effort; KeyboardInterrupt and SystemExit are not swallowed.
        uptime = '?'
    return uptime


def print_process_info(gpu_num, proc_info):
    """Print one row of the process table.

    Args:
        gpu_num: value shown in the GPU column.
        proc_info: mapping with the keys 'pid', 'process_name' and
            'used_memory' for one process.

    Returns:
        The number of lines written to stdout (always 1).
    """
    pid = proc_info['pid']
    owner = get_uname_pid(int(pid))
    uptime = get_uptime(pid)
    name = cut_proc_name(proc_info['process_name'], 27)
    used = get_num_from_str(proc_info['used_memory'])
    row = proc_format.format(gpu_num, owner, pid, uptime, name, used, 'MiB')
    print(row)
    return 1


def _as_list(node):
    """Return an xmltodict child as a list: None -> [], one item -> [item]."""
    if node is None:
        return []
    if isinstance(node, list):
        return node
    return [node]


def main():
    """Query the GPUs once and print the full report.

    Prints the header, one entry per GPU, the process table heading, one row
    per GPU process and the closing border.

    Returns:
        The total number of lines written to stdout, which the caller uses
        to move the cursor back up when redrawing.
    """
    xml = smi.XmlDeviceQuery()
    d = xmltodict.parse(xml)['nvidia_smi_log']
    header_lines = print_header(d['driver_version'])

    # xmltodict yields a mapping (not a list) when an element occurs only
    # once, so a single-GPU machine must be wrapped before iterating.
    gpus = _as_list(d.get('gpu'))

    gpu_lines = 0
    for gpu in gpus:
        gpu_lines += print_gpu_info(gpu)

    proc_header_lines = print_proc_header()

    proc_lines = 0
    for gpu in gpus:
        procs = gpu.get('processes')
        if not isinstance(procs, dict):
            # no process section (empty element or non-mapping placeholder)
            continue
        # process_info is likewise a mapping for one process, a list for many
        for p in _as_list(procs.get('process_info')):
            if isinstance(p, dict):
                proc_lines += print_process_info(gpu['minor_number'], p)
    print('+' + '-' * WIDTH + '+')

    # + 1 for the closing border printed just above
    return header_lines + gpu_lines + proc_header_lines + proc_lines + 1


if __name__ == '__main__':
    args = parser.parse_args()
    print_lines = main()
    # A positive --loop period keeps redrawing the report in place, forever
    while args.loop > 0:
        sleep(args.loop)
        # "\033[F" moves the cursor to the start of the previous line; once
        # per printed line returns it to the top of the previous report.
        sys.stdout.write("\033[F" * print_lines)
        print_lines_new = main()
        leftover = print_lines - print_lines_new
        if leftover > 0:
            # The new report is shorter: blank out the stale rows below it,
            # then step back up to just under the new report.
            sys.stdout.write((' '*79+'\n') * leftover)
            sys.stdout.write("\033[F" * leftover)
        print_lines = print_lines_new
