#!/usr/bin/env python3

"""
Sample from a parameter space and optionally instantiate templates.

Usage:

    app-sample params.ranges -n 200 -o myscan  [runcardtemplates..]

cat params.ranges
parA  0  1
parB  2  4


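If runcard templates are provided, each one is instantiated once per sampled
point: placeholders named after the parameters are replaced by the sampled
values ({N} is replaced by the zero-padded run number), for example: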

    SpaceShower:pT0Ref = {parA}
    SpaceShower:alphaSvalue = {parB}


"""

import numpy as np
from scipy.stats import qmc

def writeParams(P, templates, outdir, fname="params.dat"):
    """
    Write one numbered sub-directory per sampled point below outdir.

    For the point with index i a directory outdir/<i zero-padded> is created
    (if needed). It receives a parameter file called fname with one
    "name value" line per parameter (value in exponent notation), plus one
    instantiated copy of every template.

    P          -- sequence of dicts mapping parameter name to numeric value
    templates  -- dict mapping output file name to template text; the text is
                  expanded with str.format using the parameter names and the
                  extra field N (the zero-padded run number)
    outdir     -- base output directory
    fname      -- name of the parameter file written in each sub-directory

    The dicts in P are not modified. Nothing is written if P is empty.
    """
    import os
    from os.path import join

    if len(P) == 0:
        return

    # Loop invariant, so computed once. The formula is kept as-is so that the
    # directory names stay identical to those of earlier scans.
    width = 1 + int(np.ceil(np.log10(len(P))))

    for num, p in enumerate(P):
        npad = "{}".format(num).zfill(width)
        outd = join(outdir, npad)
        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(outd, exist_ok=True)

        with open(join(outd, fname), "w") as pf:
            for k, v in p.items():
                pf.write("{name} {val:e}\n".format(name=k, val=v))

        ## Instantiate template(s)
        # The run number goes into a copy: storing the string in the caller's
        # dict would break a later call, which formats every value with ":e".
        fields = dict(p, N=npad)
        for tbasename, tmpl in templates.items():
            with open(join(outd, tbasename), "w") as tf:
                tf.write(tmpl.format(**fields))

def sample(boxdef, npoints, method, seed=None):
    """
    Sample npoints points from the parameter box described in the file boxdef.

    boxdef   -- path of the box definition. If the first non-blank character
                is "{" the file is read as JSON of the form
                {"name": [min, max], ...}; otherwise as text with one
                "name min max" entry per line (blank lines are skipped)
    npoints  -- number of points to generate
    method   -- "uniform" (numpy pseudo-random), "lhs" (Latin hypercube) or
                "sobol" (Sobol sequence)
    seed     -- optional seed. For "uniform" it seeds numpy's global random
                state; for "lhs" and "sobol" it is passed to the scipy sampler

    Returns a list of npoints dicts mapping parameter name to sampled value.
    Raises ValueError if method is not one of the names above.
    """
    import json

    with open(boxdef) as f:
        text = f.read()

    if text.lstrip().startswith("{"):
        B = json.loads(text)
    else:
        B = {}
        for line in text.splitlines():
            fields = line.split()
            if not fields:
                # Tolerate blank or whitespace-only lines, e.g. at end of file
                continue
            B[fields[0]] = [float(fields[1]), float(fields[2])]

    # Alphabetical sort for ranges
    porder = sorted(B)
    xmin = [B[x][0] for x in porder]
    xmax = [B[x][1] for x in porder]

    ndim = len(porder)

    if method == "uniform":
        # Randomly sampled points
        if seed is not None:
            # Use the argument, not the script-level option object, so the
            # function also works when imported from another module.
            np.random.seed(seed)
        points = np.random.uniform(low=xmin, high=xmax, size=(npoints, ndim))
    elif method in ("lhs", "sobol"):
        engine = qmc.LatinHypercube if method == "lhs" else qmc.Sobol
        # The engines produce points in the unit hypercube; rescale to the box
        unit = engine(ndim, seed=seed).random(n=npoints)
        points = qmc.scale(unit, xmin, xmax)
    else:
        raise ValueError("Requested sampling method {} not implemented".format(method))

    # As dictionaries
    return [dict(zip(porder, x)) for x in points]


if __name__ == "__main__":
    # Command-line driver: parse options, sample the parameter box given as
    # first positional argument and write one run directory per point,
    # instantiating any template files given as further arguments.
    import optparse, os, sys
    op = optparse.OptionParser(usage=__doc__)
    op.add_option("-v", "--debug", dest="DEBUG", action="store_true", default=False, help="Turn on some debug messages")
    op.add_option("-o", dest="OUTDIR", default="tune", help="Output directory (default: %default)")
    op.add_option("-n", "--npoints", dest="NPOINTS", default=100, type=int, help="Number of points to sample (default: %default)")
    op.add_option("-s", "--seed", dest="SEED", default=None, type=int, help="The base random seed (default: %default)")
    op.add_option("-m", "--method", dest="METHOD", default="lhs", help="The sampling method (default: %default, options: uniform, sobol, lhs)")
    opts, args = op.parse_args()

    if opts.METHOD not in "uniform sobol lhs".split():
        print("Error, the selected sampling method {} is not known. Exiting...".format(opts.METHOD))
        sys.exit(1)

    if len(args) == 0:
        print("Error, no command line arguments given. Provide at least the text file with parameter ranges. Exiting...")
        sys.exit(1)

    # Check every input file up front so that nothing is written to the
    # output directory when one of them is missing.
    for arg in args:
        if not os.path.exists(arg):
            print("Error, the file {} does not exist. Exiting...".format(arg))
            sys.exit(1)

    # Sample n points from range given in file
    PP = sample(args[0], opts.NPOINTS, opts.METHOD, opts.SEED)

    # Template texts keyed by base name; the base name is reused as the name
    # of the instantiated file inside each run directory.
    TEMPLATES = {}
    for templatefile in args[1:]:
        tname = os.path.basename(templatefile)
        with open(templatefile, "r") as f:
            TEMPLATES[tname] = f.read()

    writeParams(PP, TEMPLATES, opts.OUTDIR)
    print("Done. Output written to {}.".format(opts.OUTDIR))
