#!/usr/bin/python3

import apprentice as app
import numpy as np


def parse_order(order):
    """Parse the value of the ``--order`` option into a pair of integers.

    Args:
        order: The raw option string, expected to look like ``"M,N"``
            (numerator order, denominator order), or ``None`` when the
            option was not given on the command line.

    Returns:
        A tuple ``(M, N)`` of ints.

    Raises:
        ValueError: If ``order`` is ``None`` or is not exactly two
            comma-separated integers.
    """
    if order is None:
        raise ValueError("no polynomial orders given, use --order M,N")
    fields = order.split(",")
    if len(fields) != 2:
        raise ValueError("--order expects exactly two comma separated integers, got '{}'".format(order))
    try:
        m, n = int(fields[0]), int(fields[1])
    except ValueError:
        raise ValueError("--order expects exactly two comma separated integers, got '{}'".format(order)) from None
    return m, n


def merge_rank_results(rankApps, rankEdges):
    """Merge the per-rank result dictionaries into one JSON-ready document.

    Args:
        rankApps: Iterable of dicts, one per rank, mapping a bin id to the
            dictionary representation of its approximation.
        rankEdges: Iterable of dicts, one per rank, mapping a bin id to a
            ``(xmin, xmax)`` pair.

    Returns:
        A tuple ``(document, count)``. ``document`` is an ``OrderedDict``
        holding every approximation keyed by bin id, followed by the two
        bookkeeping entries ``"__xmin"`` and ``"__xmax"`` (lists in the same
        order as the approximations). ``count`` is the number of
        approximations, i.e. it excludes the two bookkeeping entries.

    Raises:
        KeyError: If a bin id has an approximation but no edges.
    """
    from collections import OrderedDict

    allApps, allEdges = {}, {}
    for rankDict in rankApps:
        allApps.update(rankDict)
    for rankDict in rankEdges:
        allEdges.update(rankDict)

    document = OrderedDict()
    lower, upper = [], []
    for binId, approx in allApps.items():
        lower.append(allEdges[binId][0])
        upper.append(allEdges[binId][1])
        document[binId] = approx
    # TODO delete __xmin __xmax
    document["__xmin"] = lower
    document["__xmax"] = upper
    return document, len(allApps)


if __name__ == "__main__":
    # Command-line driver: read input data (HDF5 file or YODA run
    # directories), compute one approximation per object, optionally spread
    # over MPI ranks, and write all results to a single JSON file on rank 0.

    import datetime
    import json
    import optparse
    import os
    import sys
    import time
    # h5py is not referenced directly in this script; the import is kept to
    # preserve the original import behaviour.
    import h5py

    op = optparse.OptionParser(usage=__doc__)
    op.add_option("-v", "--debug", dest="DEBUG", action="store_true", default=False, help="Turn on some debug messages")
    op.add_option("-o", dest="OUTPUT", default="approx.json", help="Output filename (default: %default)")
    op.add_option("-w", dest="WEIGHTS", default=None, help="Obervable file (default: %default)")
    op.add_option("-s", dest="SEED", type=int, default=1234, help="Random seed (default: %default)")
    op.add_option("--order", dest="ORDER", default=None, help="Polynomial orders of numerator and denominator, comma separated (default: %default)")
    op.add_option("--mode", dest="MODE", default="sip", help="Base algorithm  --- la |sip|lasip --- (default: %default)")
    op.add_option("--errs", dest="ERRS", action='store_true', default=False, help="Build approximations for errors, (default is for values)")
    op.add_option("--log", dest="ISLOG", action='store_true', default=False, help="input data is logarithmic --- affects how we filter (default: %default)")
    op.add_option("--ftol", dest="FTOL", type=float, default=1e-9, help="ftol for SLSQP (default: %default)")
    op.add_option("--pname", dest="PNAME", default="params.dat", help="Name of the params file to be found in each run directory (default: %default)")
    op.add_option("--itslsqp", dest="ITSLSQP", type=int, default=200, help="maxiter for SLSQP (default: %default)")
    op.add_option("--msg", dest="MSGEVERY", default=5, type=int, help="Verbosity of progress (default: %default)")
    op.add_option("-t", "--testpoles", dest="TESTPOLES", type=int, default=10, help="Number of multistarts for pole detection (default: %default)")
    op.add_option("--convert", dest="CONVERTINPUT", default=None, help="Option to store input data as hdf, needs argument (default: %default)")
    opts, args = op.parse_args()
    # NOTE(review): opts.SEED and opts.ISLOG are parsed but never read by
    # this script (a former filtering step keyed on --log was commented out
    # in the original) --- confirm whether they should be wired up.

    # Fall back to a single-process run when MPI cannot be set up.
    rank = 0
    size = 1
    try:
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        size = comm.Get_size()
        rank = comm.Get_rank()
    except Exception as e:
        print("Exception when trying to import mpi4py:", e)
        comm = None
        # Make sure a partially completed setup cannot leave size > 1
        # together with comm = None.
        rank = 0
        size = 1

    if opts.MODE not in ("la", "sip", "lasip"):
        print("Error: specified mode {} not known".format(opts.MODE))
        sys.exit(1)

    # Validate --order before any data is read: it has no usable default.
    try:
        M, N = parse_order(opts.ORDER)
    except ValueError as err:
        print("Error: {}".format(err))
        sys.exit(1)

    if not args:
        print("No input specified, exiting")
        sys.exit(1)

    if not os.path.exists(args[0]):
        print("Input '{}' not found, exiting.".format(args[0]))
        sys.exit(1)

    # Prevent overwriting of input data. An explicit check is used instead of
    # an assert because asserts are removed when Python runs with -O.
    if args[0] == opts.OUTPUT:
        print("Output '{}' is identical to the input, refusing to overwrite, exiting.".format(opts.OUTPUT))
        sys.exit(1)

    # Data loading and distribution of work
    # (rankIdx is returned by the readers but not needed below.)
    if os.path.isfile(args[0]):
        DATA, binids, pnames, rankIdx, xmin, xmax = app.io.readInputDataH5(args[0], opts.WEIGHTS)
    elif os.path.isdir(args[0]):
        # YODA directory parsing here
        DATA, binids, pnames, rankIdx, xmin, xmax = app.io.readInputDataYODA(args, opts.PNAME, opts.WEIGHTS, storeAsH5=opts.CONVERTINPUT)
    else:
        print("{} neither directory nor file, exiting".format(args[0]))
        sys.exit(1)

    if size > 1:
        comm.barrier()  # TODO why do we need a barrier here?
    print("[{}] will proceed to calculate approximations for {} objects".format(rank, len(DATA)))
    sys.stdout.flush()

    t4 = time.time()
    nData = len(DATA)
    binedges = {}
    dapps = {}
    for num, (X, Y, E) in enumerate(DATA):
        thisBinId = binids[num]

        # Progress line, only from the first and last rank. A non-positive
        # --msg value disables it (and avoids a modulo by zero).
        if (rank == 0 or rank == size - 1) and opts.MSGEVERY > 0 and (num + 1) % opts.MSGEVERY == 0:
            now = time.time()
            tel = now - t4
            ttg = tel * (nData - num) / (num + 1)
            eta = datetime.datetime.fromtimestamp(now + ttg)
            sys.stdout.write("{}[{}] {}/{} (elapsed: {:.1f}s, to go: {:.1f}s, ETA: {})\r".format(80*" " if rank>0 else "", rank, num+1, nData, tel, ttg, eta.strftime('%Y-%m-%d %H:%M:%S')))
            sys.stdout.flush()

        if len(X) == 0:
            print("No data to calculate approximation for {} --- skipping".format(thisBinId))
            sys.stdout.flush()
            continue

        # Skip objects with fewer points than the approximation has coefficients.
        nCoeffs = app.tools.numCoeffsRapp(len(X[0]), order=(M, N))
        if len(X) < nCoeffs:
            print("Not enough data ({} vs {}) to calculate approximation for {} --- skipping".format(len(X), nCoeffs, thisBinId))
            sys.stdout.flush()
            continue

        # Approximate either the errors (--errs) or the values.
        target = E if opts.ERRS else Y
        temp, hasPole = app.tools.calcApprox(X, target, (M, N), pnames, opts.MODE, debug=opts.DEBUG, testforPoles=opts.TESTPOLES, ftol=opts.FTOL, itslsqp=opts.ITSLSQP)
        if temp is None:
            print("Unable to calculate value approximation for {} --- skipping".format(thisBinId))
            sys.stdout.flush()
            continue
        if hasPole:
            print("Warning: pole detected in {}".format(thisBinId))
            sys.stdout.flush()

        # Attach the range of the fitted data and the bin edges to the result.
        temp._vmin = np.min(target)
        temp._vmax = np.max(target)
        temp._xmin = xmin[num]
        temp._xmax = xmax[num]
        dapps[thisBinId] = temp.asDict
        binedges[thisBinId] = (xmin[num], xmax[num])

    # Collect the per-rank dictionaries on rank 0.
    if size > 1:
        DAPPS = comm.gather(dapps, root=0)
        DEDGE = comm.gather(binedges, root=0)
    else:
        DAPPS = [dapps]
        DEDGE = [binedges]

    t5 = time.time()
    if rank == 0:
        print()
        print("Approximation calculation took {} seconds".format(t5 - t4))
        sys.stdout.flush()

        # Store in JSON
        JD, nApprox = merge_rank_results(DAPPS, DEDGE)
        with open(opts.OUTPUT, "w") as f:
            json.dump(JD, f, indent=4)

        # Report the number of approximations, not len(JD), which would also
        # count the "__xmin" and "__xmax" entries.
        print("Done --- {} approximations written to {}".format(nApprox, opts.OUTPUT))

    sys.exit(0)
