#!/home/tgmaxson/anaconda3/bin/python
from subprocess import Popen, PIPE
from columnar import columnar

class bcolors:
    """ANSI terminal escape sequences used to colour the report output.

    Concatenate one of these before a string and ``ENDC`` after it to
    restore the terminal's default attributes.
    """

    # Foreground colours.
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Attribute reset.
    ENDC = "\033[0m"

    # Text attributes.
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def _key_value(key, line):
    """Return the value of the first space-separated ``key=value`` token, or None.

    Everything after the first ``=`` in the matching token is the value, so a
    value that itself contains ``=`` is kept whole.
    """
    for token in line.split(" "):
        name, sep, value = token.partition("=")
        if sep and name == key:
            return value
    return None


def _iter_gpu_gres(gres):
    """Yield ``(gpu_type, count)`` for each typed GPU entry in a GRES string.

    Handles strings such as ``gpu:a100:1(IDX:0)`` or
    ``gpu:v100:1(IDX:0),gpu:v100-32:1(IDX:1)``. Untyped entries (``gpu:0``),
    ``(null)`` and resources not named ``gpu`` are skipped.
    """
    for entry in gres.split(","):
        parts = entry.split(":")
        # Only typed GPU entries are shown; other GRES kinds must not be
        # mistaken for (or overwrite the totals of) GPUs of the same type.
        if len(parts) <= 2 or parts[0] != "gpu":
            continue
        yield parts[1], int(parts[2].split("(")[0])


def _gpu_cells(gres, gres_used):
    """Return one coloured cell per GPU: green when free, red when in use.

    Cells are emitted in GresUsed order, free ones before used ones per type.
    """
    totals = dict(_iter_gpu_gres(gres))
    cells = []
    for gpu_type, used in _iter_gpu_gres(gres_used):
        # A type missing from Gres is treated as fully used instead of
        # raising KeyError; max() guards against negative free counts.
        free = max(totals.get(gpu_type, used) - used, 0)
        cells += [bcolors.OKGREEN + gpu_type + bcolors.ENDC] * free
        cells += [bcolors.FAIL + gpu_type + bcolors.ENDC] * used
    return cells


def _parse_node(record):
    """Parse one ``scontrol show node`` record (the text after ``NodeName=``).

    Returns ``(row, reason)``:

    * ``row`` is ``[name, avx2, cpu_free, cpu_alloc, cpu_total, mem_free,
      mem_alloc, mem_total, mem_total_per_cpu, *gpu_cells]``. Memory values
      are the reported AllocMem/RealMemory integer-divided by 1024 (presumably
      MB -> GB; the table header labels them "Gb").
    * ``reason`` is the node's ``Reason=`` text, or None when absent.
    """
    lines = record.split("\n")
    node_name = lines[0].split(" ")[0]

    tracked = ("CPUAlloc", "CPUTot", "AllocMem", "RealMemory",
               "ActiveFeatures", "Gres", "GresUsed")
    fields = {}
    reason = None
    for line in lines:
        for key in tracked:
            value = _key_value(key, line)
            if value:
                fields[key] = value
        if _key_value("Reason", line):
            # Reason text contains spaces (and may contain '='), so keep the
            # whole remainder of the line, not a single token.
            reason = line.split("=", 1)[1]

    cpu_alloc = int(fields.get("CPUAlloc", 0))
    cpu_total = int(fields.get("CPUTot", 0))
    memory_alloc = int(fields.get("AllocMem", 0)) // 1024
    memory_total = int(fields.get("RealMemory", 0)) // 1024
    avx2 = "avx2" in fields.get("ActiveFeatures", "")
    # Avoid ZeroDivisionError for nodes that report no CPUs.
    memory_per_cpu = memory_total // cpu_total if cpu_total else 0

    row = [node_name, avx2, cpu_total - cpu_alloc, cpu_alloc, cpu_total,
           memory_total - memory_alloc, memory_alloc, memory_total,
           memory_per_cpu,
           *_gpu_cells(fields.get("Gres", ""), fields.get("GresUsed", ""))]
    return row, reason


def _node_sort_key(row):
    """Sort key for a node row.

    Names shaped ``<group>-<int>-<int>`` sort by group, then most free cores
    first, then by the two numeric parts. Any other name sorts after those,
    alphabetically, instead of raising IndexError/ValueError.
    """
    name, free_cores = row[0], row[2]
    parts = name.split("-")
    if len(parts) >= 3:
        try:
            return (0, parts[-3], -free_cores, int(parts[-2]), int(parts[-1]))
        except ValueError:
            pass
    return (1, name, -free_cores, 0, 0)


def read_sinfo_nodes():
    """Print CPU, memory and GPU availability for every Slurm node.

    Runs ``scontrol show node --detail``. Each node without a ``Reason`` is
    printed as one table row. GPU cells are green when free and red when in
    use. There are at least four GPU columns, and more if any node has more
    GPUs. Nodes that report a ``Reason`` are listed in a separate warning
    table.

    Raises:
        OSError: (e.g. FileNotFoundError) if ``scontrol`` cannot be executed.
    """
    # Fixed argument list, no shell needed; communicate() reaps the child.
    with Popen(["scontrol", "show", "node", "--detail"], stdout=PIPE) as proc:
        output = proc.communicate()[0].decode('UTF-8')

    data = []
    drain_data = []
    for record in output.split("NodeName=")[1:]:
        row, reason = _parse_node(record)
        if reason:
            drain_data.append(row[0:2] + [reason])
        else:
            data.append(row)

    base_headers = ['\nNode', '\nAVX2', 'CPU Cores\nFree', '\nAlloc', '\nTotal',
                    'Memory Gb\nFree', '\nAlloc', '\nTotal', '\nTotal/CPU']
    base_width = len(base_headers)

    # Nodes may have more than four GPUs: pad every row to the widest one so
    # columnar receives a rectangular table (never fewer than 4 GPU columns).
    gpu_columns = max([4] + [len(row) - base_width for row in data])
    width = base_width + gpu_columns
    data = [row + [""] * (width - len(row)) for row in data]
    data.sort(key=_node_sort_key)

    headers = base_headers + ['GPUs\n#1']
    headers += ['\n#' + str(i) for i in range(2, gpu_columns + 1)]
    table = columnar(data, headers=headers, no_borders=True)
    print(table)

    if drain_data:
        print(bcolors.WARNING+"The following nodes are on DRAIN and must be restarted")
        table = columnar(drain_data, headers=['Node', 'AVX2', 'Reason'], no_borders=True)
        print(table)
        print(bcolors.ENDC)

if __name__ == "__main__":
    # Produce the report only when run as a script, not on import.
    read_sinfo_nodes()
