#!/usr/bin/env python3

import apprentice as app
import numpy as np
import time

def mkPlotsMinimum(TO, x0, y0=None, prefix=""):
    """Plot one-dimensional scans of the objective through the point x0.

    For every entry i of x0 a line scan is requested from
    ``TO.lineScan(x0, i)``, the objective is evaluated at each scan point
    and the resulting curve is written to ``<prefix>valley_<i>.pdf``.
    All plots share one y-range, derived from the default and the
    ``unbiased=True`` objective values of all scans.

    Parameters
    ----------
    TO : object
        Tuning objective; must provide ``pnames``, ``_freeIdx``,
        ``lineScan(x0, i)`` and ``objective(x, unbiased=...)``.
    x0 : sequence of float
        Point the scans pass through; one plot is produced per entry.
    y0 : ignored
        Not used; kept so that existing callers continue to work.
    prefix : str
        String prepended to every output file name.
    """
    import pylab

    ndim = len(x0)
    if ndim == 0:
        # Nothing to scan; also avoids np.min/np.max on an empty array.
        return

    # Axis labels: parameter names selected by the objective's _freeIdx,
    # indexed in the same order as the entries of x0.
    pnames = np.array(TO.pnames)[TO._freeIdx]

    # Evaluate every scan exactly once and reuse the values both for the
    # common y-range and for the plots themselves.
    scans = []
    curves = []
    allvalues = []
    for i in range(ndim):
        X = TO.lineScan(x0, i)
        Y = [TO.objective(x) for x in X]
        Yunb = [TO.objective(x, unbiased=True) for x in X]
        scans.append(X)
        curves.append(Y)
        # The unbiased values are not drawn, but they do enter the y-range.
        allvalues.extend((Y, Yunb))

    allvalues = np.array(allvalues)
    ymin = np.min(allvalues)
    ymax = np.max(allvalues)

    for i, (X, Y) in enumerate(zip(scans, curves)):
        pylab.clf()
        pylab.axvline(x0[i], label="x0=%.5f" % x0[i], color="k")
        pylab.plot(X[:, i], Y, label="objective")
        # TODO just normalized ratio of bias vs unbiased
        pylab.ylabel("objective")
        pylab.xlabel(pnames[i])
        # NOTE(review): the 0.9/1.1 padding only widens the range when the
        # objective values are non-negative -- confirm this always holds.
        pylab.ylim((0.9 * ymin, 1.1 * ymax))
        pylab.legend()
        pylab.tight_layout()
        pylab.savefig(prefix + "valley_{}.pdf".format(i))

def mkPlotsCorrelation(TO, x0, prefix=""):
    """Plot the parameter correlation matrix at x0 and save it as a PDF.

    The covariance is taken as the inverse of ``TO.hessian(x0)`` and
    normalised to a correlation matrix. Only the strict upper triangle is
    drawn; the diagonal and the lower triangle are masked. The figure is
    written to ``<prefix>corr.pdf``.

    Parameters
    ----------
    TO : object
        Tuning objective; must provide ``hessian(x0)`` and ``pnames``
        (and ``_freeIdx`` when ``pnames`` is longer than ``x0``).
    x0 : sequence of float
        Point at which the Hessian is evaluated.
    prefix : str
        String prepended to the output file name.
    """
    import pylab

    nd = len(x0)
    COV = np.linalg.inv(TO.hessian(x0))

    # COR[i, j] = COV[i, j] / sqrt(COV[i, i] * COV[j, j]), done in one shot.
    variances = np.diag(COV)
    COR = COV / np.sqrt(np.outer(variances, variances))

    # np.tri(..., k=0) is 1 on and below the diagonal, so exactly those
    # entries are masked out of the image.
    A = np.ma.array(COR, mask=np.tri(COR.shape[0], k=0))

    # One label per tick is required. When pnames has more entries than x0,
    # select the names via _freeIdx, the same way mkPlotsMinimum labels the
    # entries of x0; otherwise keep using pnames unchanged.
    if len(TO.pnames) == nd:
        names = TO.pnames
    else:
        names = np.array(TO.pnames)[TO._freeIdx]

    ticks = list(range(nd))
    pylab.clf()
    image = pylab.imshow(A, vmin=-1, vmax=1, cmap="RdBu")
    pylab.yticks(ticks, names, rotation=0)
    pylab.xticks(ticks, names, rotation=90)
    pylab.colorbar(image, extend='both')

    pylab.tight_layout()
    pylab.savefig(prefix + "corr.pdf")

import optparse
import os
import shutil
import sys

# ---------------------------------------------------------------------------
# Command line handling.
# Three positional arguments are required, in this order: the weight file,
# the data and the approximation; they are passed on to TuningObjective2.
# ---------------------------------------------------------------------------
op = optparse.OptionParser(usage=__doc__ or "%prog [options] WEIGHTFILE DATA APPROXIMATION")
op.add_option("-v", "--debug", dest="DEBUG", action="store_true", default=False, help="Turn on some debug messages")
op.add_option("-o", dest="OUTDIR", default="tune", help="Output directory (default: %default)")
op.add_option("-e", "--errorapprox", dest="ERRAPP", default=None, help="Approximations of bin uncertainties (default: %default)")
op.add_option("-s", "--survey", dest="SURVEY", default=1, type=int, help="Size of survey when determining start point (default: %default)")
op.add_option("-r", "--restart", dest="RESTART", default=1, type=int, help="Minimiser restarts (default: %default)")
op.add_option("--seed", dest="SEED", default=1234, type=int, help="The base random seed (default: %default)")
op.add_option("--msp", dest="MSP", default=None, help="Manual startpoint, comma separated string (default: %default)")
op.add_option("-a", "--algorithm", dest="ALG", default="tnc", help="The minimisation algorithm tnc, ncg, lbfgsb, trust (default: %default)")
op.add_option("-l", "--limits", dest="LIMITS", default=None, help="Parameter file with limits and fixed parameters (default: %default)")
# NOTE: -f is parsed but not consulted anywhere below; existing output is
# currently overwritten regardless of this flag.
op.add_option("-f", dest="FORCE", default=False, action='store_true', help="Allow overwriting output directory (default: %default)")
op.add_option("-p", "--plotvalley", dest="PLOTVALLEY", default=False, action='store_true', help="Plot line scans of the objective through the best fit point (default: %default)")
op.add_option("--tol", dest="TOL", default=1e-6, type=float, help="Tolerance for scipy optimize minimize (default: %default)")
op.add_option("--no-check", dest="NOCHECK", default=False, action="store_true", help="Don't check for saddle points (default: %default)")
opts, args = op.parse_args()


if opts.ALG not in ["tnc", "ncg", "lbfgsb", "trust"]:
    raise Exception("Minimisation algorithm {} not implemented, should be tnc, ncg, lbfgsb or trust, exiting".format(opts.ALG))

# Fail with a usage message instead of an IndexError, and do so before any
# output directory is created.
if len(args) < 3:
    op.error("expected 3 positional arguments (weight file, data, approximation), got {}".format(len(args)))

WFILE = args[0]
DATA  = args[1]
APP   = args[2]

#TODO add protections and checks if files exist etc
os.makedirs(opts.OUTDIR, exist_ok=True)

np.random.seed(opts.SEED)

# ---------------------------------------------------------------------------
# Set up the objective and run the minimisation.
# ---------------------------------------------------------------------------
GOF = app.appset.TuningObjective2(WFILE, DATA, APP, f_errors=opts.ERRAPP, debug=opts.DEBUG)
if opts.LIMITS is not None:
    GOF.setLimitsAndFixed(opts.LIMITS)

if opts.MSP is not None:
    x0 = [float(x) for x in opts.MSP.split(",")]
    GOF.setManualStartPoint(x0)

t0 = time.time()
res = GOF.minimize(opts.SURVEY, opts.RESTART, method=opts.ALG, tol=opts.TOL, saddlePointCheck=not opts.NOCHECK)
t1 = time.time()
if opts.DEBUG:
    print(res)

# Objective at the best fit point evaluated with unbiased=True; reported
# below as the value "without weights".
chi2 = GOF.objective(res.x, unbiased=True)
ndf = GOF.ndf

meta  = "# Objective value at best fit point: %.2f (%.2f without weights)\n"%(res.fun, chi2)
meta += "# Degrees of freedom: {}\n".format(ndf)
meta += "# phi2/ndf: %.3f\n"%(chi2/ndf)
meta += "# Minimisation took {} seconds\n".format(t1-t0)
meta += "# Command line: {}\n".format(" ".join(sys.argv))
meta += "# Best fit point:"

print(meta)
print(GOF.printParams(res.x))

# ---------------------------------------------------------------------------
# Write results.
# ---------------------------------------------------------------------------
# Common stem of all output file names. NOTE: the seed is not part of the
# stem, so runs that differ only in --seed write to the same files.
outcommon = "{}_{}_{}".format(opts.ALG, opts.SURVEY, opts.RESTART)

# The YODA prediction output is optional and skipped on ImportError.
try:
    import yoda
    app.tools.prediction2YODA(APP, GOF.mkPoint(res.x), os.path.join(opts.OUTDIR, "predictions_{}.yoda".format(outcommon)), opts.ERRAPP)
except ImportError as err:
    if opts.DEBUG:
        print("Skipping YODA prediction output: {}".format(err))

mkPlotsCorrelation(GOF, res.x, os.path.join(opts.OUTDIR, "{}_".format(outcommon)))
GOF.writeResult(res.x, os.path.join(opts.OUTDIR, "minimum_{}.txt".format(outcommon)), meta=meta)
if opts.PLOTVALLEY:
    plotout = os.path.join(opts.OUTDIR, "valleys_{}".format(outcommon))
    os.makedirs(plotout, exist_ok=True)
    # Trailing separator so that the plots end up inside plotout.
    mkPlotsMinimum(GOF, res.x, prefix=os.path.join(plotout, ""))

# Keep a copy of the weight file next to the results.
shutil.copy(WFILE, os.path.join(opts.OUTDIR, "weights_{}.txt".format(outcommon)))

print("Output written to directory {}.".format(opts.OUTDIR))
