#!/usr/bin/env python

import os
import argparse
import mlprocessors as mlpr
from mountaintools import client as mt
import mtlogging

# Blank out the DISPLAY environment variable for this process at import time,
# before main() runs. The original note says this can matter for some jobs in
# certain situations (presumably jobs that would otherwise try to reach an
# X display -- TODO confirm).
os.environ.update({'DISPLAY': ''})


@mtlogging.log(root=True)
def main():
    """Parse the command line and run a compute resource server.

    Reads the compute resource name and the optional settings from the
    command line, echoes them to stdout, optionally logs in and configures
    the download source through the mountaintools client, then builds an
    ``mlpr.ComputeResourceServer``, applies the srun / parallelism options
    that were given, and starts it.

    An invalid ``--parallel`` value is rejected by argparse (usage message,
    exit status 2) before any login or server setup takes place.
    """
    parser = argparse.ArgumentParser(
        description='Listen for batches as a compute resource')
    parser.add_argument(
        'compute_resource_name', help='Name of the compute resource')
    parser.add_argument('--collection', help='Name of collection (e.g., spikeforest)',
                        required=False, default=None)
    parser.add_argument(
        '--kachery_name', help='Name of share id', required=False, default=None)
    parser.add_argument(
        '--srun_opts', help='Options for slurm\'s srun', required=False, default=None)
    # type=int makes argparse validate the value while parsing, so a typo is
    # reported as a usage error instead of a ValueError raised after login.
    parser.add_argument(
        '--parallel',
        help='Number of jobs to run in parallel (should not be run in combination with srun_opts)',
        type=int, required=False, default=None)
    parser.add_argument('--login', help='Whether to login',
                        action='store_true')
    parser.add_argument('--allow_uncontainerized',
                        help='Allow jobs to run outside of containers (potentially dangerous)',
                        action='store_true')

    args = parser.parse_args()

    # Echo the effective configuration so it shows up in the process output.
    print('compute resource name: {}'.format(args.compute_resource_name))
    print('collection = {}'.format(args.collection))
    print('kachery_name = {}'.format(args.kachery_name))
    print('srun opts = {}'.format(args.srun_opts))
    print('num parallel = {}'.format(args.parallel))
    print('allow uncontainerized = {}'.format(args.allow_uncontainerized))

    if args.login:
        mt.login()

    if args.kachery_name:
        mt.configDownloadFrom(args.kachery_name)

    # NOTE(review): --allow_uncontainerized is parsed and echoed above but is
    # not forwarded to the server; the keyword argument below is disabled.
    # Confirm whether ComputeResourceServer supports it before re-enabling.
    server = mlpr.ComputeResourceServer(
        resource_name=args.compute_resource_name,
        collection=args.collection,
        kachery_name=args.kachery_name  # ,
        # allow_uncontainerized=args.allow_uncontainerized
    )
    if args.srun_opts is not None:
        server.setSrunOptsString(args.srun_opts)
    if args.parallel is not None:
        # Already an int thanks to type=int on the argument.
        server.setNumParallel(args.parallel)
    server.start()


# Run the server only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
