#!/usr/bin/env python
"""Submit weak scaling benchmarks in a loop.

Examples
--------
./submit_weak --help
./submit_weak -n
./submit_weak -n -d 2 32 32

"""
import os
import numpy as np
from util import get_parser, parser_to_params, init_cluster, submit, Cluster
import fluidsim


# --- Command-line parser and default grid sizes --------------------------
parser = get_parser(
    'submit_weak', 'Generic script to submit weak scaling benchmarks', __doc__)

# ``cluster_type`` is not defined anywhere in this script, so referencing it
# directly raised a NameError.  It is presumably exported by ``util`` next to
# ``Cluster`` -- TODO confirm.  When it is not available, fall back to the
# generic (small) defaults instead of crashing.
try:
    from util import cluster_type
except ImportError:
    cluster_type = None

# NOTE(review): the defaults below set ``nz`` while the rest of this script
# reads ``args.n2``; check which name the parser from ``util`` really uses.
if cluster_type == 'snic':
    parser.set_defaults(n0=1024, n1=1024, nz=240)  # or 1344 1344 672
else:
    parser.set_defaults(n0=64, n1=64, nz=8)

params, params_dim = parser_to_params(parser)
cluster = init_cluster(params, Cluster)
mode = params.mode

# --- Core counts used for the intra-node runs ----------------------------
# Counts go from 2 up to 2**log2(cores per node), one sample per power.
# ``np.log2`` returns a float, but ``np.logspace`` needs an integer number of
# samples (recent NumPy raises TypeError otherwise), hence the int() cast.
log_nb_cores = np.log2(cluster.nb_cores_per_node)
params_dim.nb_cores = np.logspace(
    1, log_nb_cores, int(log_nb_cores), base=2, dtype=int)

# Grid-size scaling factors: sizes are divided by ``divisors`` for the
# intra-node runs and multiplied by ``multipliers`` for the inter-node runs.
divisors = np.array(params_dim.nb_cores[::-1]) // 2
multipliers = np.array(params_dim.nodes)
args = parser.parse_args()


def submit_once(nb_nodes, nb_cores_per_node, type_fft, n0, n1, n2=None):
    """Set the grid shape on ``params_dim`` and submit one benchmark job.

    Parameters
    ----------
    nb_nodes : number of nodes passed on to ``submit``.
    nb_cores_per_node : number of cores per node passed on to ``submit``.
    type_fft : FFT identifier passed on to ``submit``.
    n0, n1 : sizes written first and second in the shape string.
    n2 : third size, only used when ``params.dim`` is 3.

    The shape is stored as a space-separated string in ``params_dim.shape``.
    Relies on the module-level ``params``, ``params_dim`` and ``cluster``.
    """
    # Two sizes unless the run is three-dimensional.
    sizes = (n0, n1, n2) if params.dim == 3 else (n0, n1)
    params_dim.shape = ' '.join('{}'.format(size) for size in sizes)

    submit(
        params,
        params_dim,
        cluster,
        nb_nodes,
        nb_cores_per_node,
        type_fft,
    )


if params.dry_run:
    print(params_dim)

# Sequential runs: one node, one core, once per sequential FFT type.
if 'seq' in mode:
    for type_fft in params_dim.fft_seq:
        submit(params, params_dim, cluster, 1, 1, type_fft)

if params.dim == 2:
    n0 = args.n0
    n1 = args.n1
    # ``np.int`` was removed from NumPy (1.24); the builtin ``int`` is the
    # dtype it used to alias.
    n1_intra = (n1 / divisors).astype(int)
    n1_inter = n1 * multipliers

    if 'intra' in mode:
        # Single node, growing number of cores; only n1 is scaled.
        nb_nodes = 1
        for type_fft in params_dim.fft:
            # ``n1_run`` avoids shadowing the base size ``n1``.
            for nb_cores_per_node, n1_run in zip(
                    params_dim.nb_cores, n1_intra):
                submit_once(
                    nb_nodes, nb_cores_per_node, type_fft, n0, n1_run,
                    n2=None)

    if 'inter' in mode:
        # Growing number of nodes; n1 is multiplied by the node count.
        for type_fft in params_dim.fft:
            for nb_nodes, n1_run in zip(params_dim.nodes, n1_inter):
                params_dim.shape = '{} {}'.format(n0, n1_run)
                submit(params, params_dim, cluster, nb_nodes, fft=type_fft)
elif params.dim == 3:
    if 'intra' in mode:
        nb_nodes = 1
        for type_fft in params_dim.fft:
            if 'fftw' in type_fft:
                # FFT types whose name contains 'fftw': only n1 is scaled.
                n0_intra = np.ones_like(divisors) * args.n0
                n1_intra = (args.n1 / divisors).astype(int)
            else:
                # Other FFT types: both n0 and n1 are scaled.  Floor division
                # keeps the sizes integral; true division produced floats
                # and shape strings such as '16.0 16.0 8'.
                n0_intra = (args.n0 / divisors).astype(int) // 2
                n1_intra = (args.n1 / divisors).astype(int) // 2

            for nb_cores_per_node, n0, n1 in zip(
                    params_dim.nb_cores, n0_intra, n1_intra):
                submit_once(
                    nb_nodes, nb_cores_per_node, type_fft, n0, n1,
                    n2=args.n2)

    if 'inter' in mode:
        for type_fft in params_dim.fft:
            if 'fftw' in type_fft:
                n0_inter = np.ones_like(multipliers) * args.n0
                n1_inter = (args.n1 * multipliers).astype(int)
            else:
                n0_inter = (args.n0 * multipliers).astype(int) * 2
                n1_inter = (args.n1 * multipliers).astype(int) * 2

            # Iterate over node counts, as the 2D 'inter' branch does.  The
            # previous loop walked ``params_dim.nb_cores`` and reused an
            # ``nb_nodes`` that was either undefined or left at 1 by the
            # 'intra' branch.
            for nb_nodes, n0, n1 in zip(
                    params_dim.nodes, n0_inter, n1_inter):
                params_dim.shape = '{} {} {}'.format(n0, n1, args.n2)
                submit(params, params_dim, cluster, nb_nodes, fft=type_fft)
