#!/project2/meren/VIRTUAL-ENVS/anvio-master/bin/python
# -*- coding: utf-8

import os
import sys
import random
import string
import tempfile
import argparse
import subprocess
import configparser

from colored import style, fore

# Text of the sbatch script that gets submitted. Each {placeholder} is filled
# in with str.format() from Clusterize.params; the random seed is appended to
# the job name so repeated submissions with the same name write to different
# .out/.err files. The list ends with '' so the script ends with a newline.
SBATCH_template = '\n'.join([
    '#!/bin/bash',
    '#SBATCH --job-name={job_name}_{seed}',
    '#SBATCH --output={job_name}_{seed}.out',
    '#SBATCH --error={job_name}_{seed}.err',
    '#SBATCH --partition={partition}',
    '#SBATCH --nodes={num_nodes}',
    '#SBATCH --ntasks-per-node={num_tasks_per_node}',
    '#SBATCH --time={allotted_time}',
    '#SBATCH --mem-per-cpu={mem_per_cpu}',
    '',
    '{command}',
    '',
])

# Built-in fallback values, used when a setting is given neither on the
# command line nor in ~/.clusterize_config. `partition` is deliberately None:
# there is no universal value, so the user must supply one (Clusterize
# refuses to submit while any setting is still None).
out_of_box_defaults = dict(
    job_name='clusterize',
    partition=None,
    num_nodes=1,
    num_tasks_per_node=1,
    allotted_time='15:00:00',  # HH:MM:SS
    mem_per_cpu=10000,         # MB, as described in the --mem-per-cpu help
)

class Clusterize(object):
    """Submit a shell command to SLURM as a batch job.

    Every setting is resolved in this order: the value given on the command
    line, then the ``[CLUSTERIZE_DEFAULTS]`` section of ``~/.clusterize_config``,
    then ``out_of_box_defaults``. Constructing an instance validates the
    settings, writes a temporary sbatch script, submits it with ``sbatch`` and
    deletes the temporary script again.
    """

    # Settings resolved from the CLI/config, in resolution (and report) order.
    param_names = ('command', 'job_name', 'partition', 'num_nodes',
                   'num_tasks_per_node', 'allotted_time', 'mem_per_cpu')

    # Settings that must be integers; validated in this order by check_params().
    integer_params = ('num_nodes', 'num_tasks_per_node', 'mem_per_cpu')

    def __init__(self, args):
        """Resolve settings from ``args`` and submit the job immediately.

        Args:
            args: parsed ``argparse.Namespace``. Attributes that are missing or
                ``None`` fall back to the config file, then the built-in defaults.
        """
        self.job_seed = self.get_job_seed()

        self.config_file_path = os.path.join(os.path.expanduser('~'), '.clusterize_config')
        self.config_file_exists = os.path.exists(self.config_file_path)
        self.user_defaults = {}

        cli_values = vars(args)
        self.params = {}
        for name in self.param_names:
            value = cli_values.get(name)
            # An explicit command-line value always wins over any default.
            self.params[name] = value if value is not None else self.get_default(name)
        self.params['seed'] = self.job_seed

        self.job_out = '{}_{}.out'.format(self.params['job_name'], self.job_seed)
        self.job_err = '{}_{}.err'.format(self.params['job_name'], self.job_seed)

        self.check_params()

        self.sbatch_filepath = self.gen_sbatch_file()

        self.run_job()

    def get_job_seed(self, length=10):
        """Return a random string of ``length`` ASCII letters.

        It is appended to the job name to keep job and output file names apart
        between submissions; it is not intended to be cryptographically secure.
        """
        letters = string.ascii_letters
        return ''.join(random.choice(letters) for _ in range(length))

    def touch(self, path):
        """Create an empty file at ``path`` unless something already exists there."""
        if not os.path.exists(path):
            # Append mode never truncates, so a file that appears concurrently
            # is left intact.
            with open(path, 'a'):
                pass

    def run_job(self):
        """Submit the generated sbatch script and report the output file names.

        The temporary sbatch script is removed whether or not submission
        succeeded. On failure an error is printed and the process exits with
        status 1.
        """
        # Pass argv as a list (no shell), so a temp path containing spaces or
        # shell metacharacters cannot break or alter the command.
        cmd = ['sbatch', self.sbatch_filepath]

        try:
            output = subprocess.check_output(cmd, universal_newlines=True)
        except (subprocess.CalledProcessError, OSError) as e:
            # OSError covers `sbatch` not being installed / not on PATH.
            print(fore.RED + 'ClusterizeError: {}'.format(e) + style.RESET)
            sys.exit(1)
        else:
            self.touch(self.job_out)
            self.touch(self.job_err)
            print(fore.GREEN + str(output.strip()))
            print('    output_file: {}'.format(self.job_out))
            print('    error_file: {}'.format(self.job_err) + style.RESET)
        finally:
            os.remove(self.sbatch_filepath)

    def check_params(self):
        """Validate ``self.params`` in place, exiting with status 1 on error.

        Every setting must be non-``None``; the settings in ``integer_params``
        are converted to ``int``.
        """
        for param, param_value in self.params.items():
            if param_value is None:
                flag = '--' + param.replace('_', '-')
                print(fore.RED +
                      '{} is set to None. Either provide the flag {} or set a default in {} '
                      '(see --gen-config-file)'.format(param, flag, self.config_file_path) +
                      style.RESET)
                sys.exit(1)

        for param in self.integer_params:
            try:
                self.params[param] = int(self.params[param])
            except (TypeError, ValueError):
                print(fore.RED + '{} must be an integer. its currently `{}`'.format(param, self.params[param]) + style.RESET)
                sys.exit(1)

    def gen_sbatch_file(self, filepath=None):
        """Render the sbatch script from ``self.params`` and write it to disk.

        Args:
            filepath: destination path. When omitted, a new temporary file with
                prefix ``clusterize_<seed>_`` is created and left on disk
                (``run_job`` removes it after submission).

        Returns:
            The path the script was written to.
        """
        # Render first so a formatting error cannot leave an orphaned temp file.
        file_as_str = SBATCH_template.format(**self.params)

        if filepath:
            with open(filepath, 'w') as f:
                f.write(file_as_str)
            return filepath

        with tempfile.NamedTemporaryFile(mode='w', delete=False,
                                         prefix='clusterize_' + self.job_seed + '_') as f:
            f.write(file_as_str)
        return f.name

    def get_default(self, param_name):
        """Return the default for ``param_name`` from the config file or built-ins.

        The config file is read lazily on first use. An empty value, or the
        literal string ``'None'`` (which older generated config files contain
        for unset settings), counts as "not set" and falls through to
        ``out_of_box_defaults``.
        """
        if self.config_file_exists and not self.user_defaults:
            config = configparser.ConfigParser()
            config.read(self.config_file_path)
            # A config file without the expected section simply contributes
            # no defaults instead of crashing with KeyError.
            if config.has_section('CLUSTERIZE_DEFAULTS'):
                self.user_defaults.update(config['CLUSTERIZE_DEFAULTS'])

        value = self.user_defaults.get(param_name)
        if not value or value == 'None':
            value = out_of_box_defaults.get(param_name)
        return value


def main(args):
    """Resolve settings from the parsed CLI ``args`` and submit the job.

    Constructing ``Clusterize`` performs the submission; the instance is
    returned so callers can inspect its ``params`` and output file names.
    """
    return Clusterize(args)


def get_default_config(path=None):
    """Write a template config file holding the out-of-box defaults.

    Args:
        path: where to write the file. Defaults to ``~/.clusterize_config``,
            the location Clusterize reads its defaults from.

    If the file already exists it is left untouched, an error is printed and
    the process exits with status 1.
    """
    if path is None:
        path = os.path.join(os.path.expanduser('~'), '.clusterize_config')

    if os.path.exists(path):
        print(fore.RED + '{} already exists. delete it first if you want to create a new default config'.format(path) + style.RESET)
        sys.exit(1)

    config = configparser.ConfigParser()
    # Unset defaults (e.g. partition) are written as empty values: str(None)
    # would write the literal string 'None', which no longer looks unset once
    # it is read back from the file.
    config['CLUSTERIZE_DEFAULTS'] = {
        k: '' if v is None else str(v) for k, v in out_of_box_defaults.items()
    }
    with open(path, 'w') as configfile:
        config.write(configfile)
    print(fore.GREEN + 'default config file written to {}'.format(path) + style.RESET)


if __name__ == '__main__':
    # Command-line entry point: either write a template config file or
    # submit the given command as a SLURM job.
    ap = argparse.ArgumentParser()

    groupP = ap.add_argument_group('THE COMMAND')
    # nargs='?' lets `--gen-config-file` be used on its own; the command is
    # required again below whenever a job is actually submitted.
    groupP.add_argument('command', type=str, nargs='?',
                        help='Your bash command that will be submitted as a SLURM job. '
                             'It must be contained in double quotes, i.e. "<your command>". '
                             'If the command itself contains DOUBLE-quotes, prefix each '
                             'of them with a backslash, e.g. "echo \\"$HOME\\" > test"')

    groupA = ap.add_argument_group('CONFIG', "Some of these options are required, and can be set permanently by "
                                             "creating a config file. See 'DEFAULT CONFIG'. Any parameter set "
                                             "here overwrites that found in the config file")

    groupA.add_argument('-p', '--partition', type=str, help='Which partition of the cluster are you using?')
    groupA.add_argument('-o', '--job-name', type=str,
                        help='Give a useful name to your job. It doesn\'t have to be unique, '
                             'as a unique ID is appended to whatever you pick. '
                             'This name will show up in the SLURM queue, and will '
                             'be the prefix for the .out and .err files of the job. '
                             'As an example, if the job name is `job`, it may show '
                             'in the queue as `job_hcwknUCbSr`, and the outputs will '
                             'be `job_hcwknUCbSr.out` and `job_hcwknUCbSr.err`. '
                             'The default is simply {}'.format(out_of_box_defaults['job_name']))
    groupA.add_argument('-N', '--num-nodes', type=int,
                        help='How many nodes you want to use? default is {}'.format(out_of_box_defaults['num_nodes']))
    groupA.add_argument('-n', '--num-tasks-per-node', type=int,
                        help='How many tasks to run on each node? default {}'.format(out_of_box_defaults['num_tasks_per_node']))
    groupA.add_argument('-t', '--allotted-time',
                        help='After this amount of time, the process will be killed :( The '
                             'default is {}, which could be higher than your cluster allows. '
                             'One acceptable time format is HH:MM:SS, i.e. 15 hours would '
                             'be: `15:00:00`'.format(out_of_box_defaults['allotted_time']))
    groupA.add_argument('-M', '--mem-per-cpu', type=int,
                        help='How much memory in MB should be allotted per cpu? Default '
                             'is {}'.format(out_of_box_defaults['mem_per_cpu']))

    groupB = ap.add_argument_group('DEFAULT CONFIG', "Create a default config file located at ~/.clusterize_config")
    groupB.add_argument('--gen-config-file', action='store_true',
                        help="This generates a template that clusterize will read to get its parameters. "
                             "All parameters in this file will be overwritten when parameters are provided "
                             "explicitly to clusterize through the command line.")

    args = ap.parse_args()

    if args.gen_config_file:
        get_default_config()
    elif args.command is None:
        # Same message argparse itself produces for a missing positional.
        ap.error('the following arguments are required: command')
    else:
        main(args)

