#!/usr/bin/env python

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

from __future__ import print_function

import sys
import getopt
import os
import os.path
import dill as pickle
import time
import datetime
import logging
import numpy as np
import scipy.stats as stats
import TransportMaps as TM
import TransportMaps.CLI as TMCLI
import TransportMaps.Maps as MAPS
import TransportMaps.Distributions as DIST

# Make modules in the current working directory importable; presumably this
# is needed so that pickled objects referring to user-defined modules can be
# unpickled -- TODO confirm.
sys.path.append(os.getcwd())

# Data storage object: an instance of an anonymous empty class, used as a
# mutable namespace for the run settings and results. The whole object is
# pickled to the output file at the end of the run.
stg = type('', (), {})()

def usage():
    """Print the short usage synopsis of ``tmap-laplace``.

    The synopsis lists exactly the flags accepted by the ``getopt`` parser
    of this script: ``-h``, ``-I`` and the long options.
    """
    usage_str = """
Usage: tmap-laplace [-h -I] --dist=DIST --output=OUTPUT
  [--tol=TOL --ders=DERS --fungrad
    --hessact --hessact-rnd-eps=EPS --hessact-pow-n=N --hessact-ovsamp=N
    --overwrite --reload --log=LOG]
"""
    print(usage_str)

def full_usage():
    """Print the usage synopsis (currently the same text as ``usage``)."""
    return usage()

def description():
    """Print the DESCRIPTION and OPTIONS sections of the help text.

    The ``--ders`` and ``--log`` entries are each followed by the text
    returned by ``TMCLI.print_avail_options`` (presumably the list of
    admissible values), indented to the description column.
    """
    column_indent = '                          '
    head = """DESCRIPTION
Given a file (--dist) storing the target distribution, generate the linear map
corresponding to the Laplace approximation of it.
All files involved are stored and loaded using the python package dill.

OPTIONS
  --dist=DIST             file containing the target distribution 
  --output=OUTPUT         output file containing the linear transport map,
                          the base distribution (standard normal), the target distribution
                          and all the additional parameters used for the
                          construction
  --tol=TOL               optimization tolerance (default: 1e-4)
"""
    ders_section = (
        '  --ders=DERS             derivatives to be used in the optimization\n'
        + TMCLI.print_avail_options(TMCLI.AVAIL_DERS, column_indent))
    log_section = (
        '  --log=LOG               log level (default=30). Uses package logging.\n'
        + TMCLI.print_avail_options(TMCLI.AVAIL_LOGGING, column_indent))
    tail = """  --fungrad              whether the distributions provide a method to compute
                          the log pdf and its gradient at the same time
  --hessact               whether the distribution provides a method to compute
                          the action of the Hessian of the log pdf on a vector
  --hessact-rnd-eps=EPS   tolerance to be used in the pursue of a randomized
                          low-rank approximation of the prior preconditioned
                          Hessian of the log-likelihood
  --hessact-pow-n=N       number of power iterations to be used in the pursue of a randomized
                          low-rank approximation of the prior preconditioned
                          Hessian of the log-likelihood
  --hessact-ovsamp=N      oversampling to be used in the pursue of a randomized
                          low-rank approximation of the prior preconditioned
                          Hessian of the log-likelihood
  --overwrite             overwrite file if it exists
  --reload                reload file if it exists
  -I                      enter interactive mode after finishing
  -h                      print this help
"""
    print(''.join([head, ders_section, log_section, tail]))

def full_doc():
    """Print the complete help: usage synopsis first, then the option details."""
    for print_section in (full_usage, description):
        print_section()

##################### INPUT PARSING #####################
argv = sys.argv[1:]
INTERACTIVE = False
# I/O
OUT_FNAME = None
DIST_FNAME = None
# Solver options. They are stored on ``stg`` so that they are pickled
# together with the results at the end of the run.
stg.TOL = 1e-4
stg.DERS = 2
stg.FUNGRAD = False
stg.HESSACT = False
stg.HESSACT_RND_EPS = 1e-5
stg.HESSACT_POW_N = 0
stg.HESSACT_OVSAMP = 10
# Overwriting/reloading
OVERWRITE = False
RELOAD = False
# Logging
LOGGING_LEVEL = 30 # Warnings
try:
    opts, args = getopt.getopt(argv,"hI",[
        "output=", "dist=",
        "tol=", "ders=", "fungrad",
        "hessact", "hessact-rnd-eps=", "hessact-pow-n=", "hessact-ovsamp=",
        # Overwriting and reloading
        "overwrite", "reload",
        # Logging
        "log="])
except getopt.GetoptError:
    # Unknown option or missing option argument
    full_usage()
    sys.exit(2)
for opt, arg in opts:
    try:
        if opt == '-h':
            full_doc()
            sys.exit()
        # I/O
        elif opt == "--output":
            OUT_FNAME = arg
        elif opt == "--dist":
            DIST_FNAME = arg
        # Solver options
        elif opt == '--tol':
            stg.TOL = float(arg)
        elif opt == '--ders':
            stg.DERS = int(arg)
        elif opt == '--fungrad':
            stg.FUNGRAD = True
        elif opt == '--hessact':
            stg.HESSACT = True
        elif opt == '--hessact-rnd-eps':
            stg.HESSACT_RND_EPS = float(arg)
        elif opt == '--hessact-pow-n':
            stg.HESSACT_POW_N = int(arg)
        elif opt == '--hessact-ovsamp':
            stg.HESSACT_OVSAMP = int(arg)
        # Overwriting/reloading
        elif opt == '--overwrite':
            OVERWRITE = True
        elif opt == '--reload':
            RELOAD = True
        # Logging
        elif opt == '--log':
            LOGGING_LEVEL = int(arg)
        # Interactive
        elif opt == "-I":
            INTERACTIVE = True
        else:
            # Defensive: getopt only returns options listed above
            raise RuntimeError("Input option %s not recognized." % opt)
    except ValueError:
        # Malformed numeric value (e.g. --tol=abc): report it like any other
        # command-line error instead of crashing with a traceback.
        full_usage()
        print("ERROR: Invalid value '%s' for option %s" % (arg, opt))
        sys.exit(2)

def tstamp_print(msg, *args, **kwargs):
    """Print ``msg`` prefixed by the current local time.

    The prefix has the form ``YYYY-mm-dd HH:MM:SS`` followed by a single
    space. Extra positional and keyword arguments (e.g. ``sep``, ``end``,
    ``file``) are forwarded unchanged to ``print``.
    """
    now = time.time()
    stamp = datetime.datetime.fromtimestamp(now).strftime('%Y-%m-%d %H:%M:%S')
    print(" ".join([stamp, msg]), *args, **kwargs)
        
# Read a line from the user verbatim. In Python 2 ``input`` evaluates the
# typed text (so answering ``o`` would raise NameError); ``raw_input`` does not.
try:
    _read_answer = raw_input
except NameError:
    _read_answer = input

if OUT_FNAME is None:
    full_usage()
    tstamp_print("ERROR: Option --output must be specified")
    sys.exit(3)

# Ask what to do when the output file already exists and neither
# --overwrite nor --reload was given.
if not OVERWRITE and not RELOAD and os.path.exists(OUT_FNAME):
    sel = ''
    while sel not in ['o', 'r', 'q']:
        sel = _read_answer(
            "The file %s already exists. " % OUT_FNAME +
            "Do you want to overwrite (o), reload (r) or quit (q)? [o/r/q] "
        ).strip().lower()
    if sel == 'o':
        while sel not in ['y', 'n']:
            sel = _read_answer(
                "Please, confirm that you want overwrite. [y/n]").strip().lower()
        if sel == 'n':
            tstamp_print("Terminating.")
            sys.exit(0)
    elif sel == 'r':
        RELOAD = True
    else:
        tstamp_print("Terminating.")
        sys.exit(0)

# These checks run after the prompt because the user may switch to
# reload mode there, which makes --dist unnecessary.
if RELOAD and not os.path.exists(OUT_FNAME):
    full_usage()
    tstamp_print("ERROR: Cannot reload: file %s does not exist" % OUT_FNAME)
    sys.exit(3)
if not RELOAD and DIST_FNAME is None:
    full_usage()
    tstamp_print("ERROR: Option --dist must be specified (unless reloading)")
    sys.exit(3)
    
# Main computation. If -I was given, the ``finally`` clause opens an IPython
# shell even when an exception was raised, so the state can be inspected.
try:
    TM.setLogLevel(LOGGING_LEVEL)

    # NOTE(security): dill/pickle can execute arbitrary code while loading;
    # only load files from trusted sources.
    if RELOAD:
        # The stored settings object replaces ``stg`` wholesale, so solver
        # options given on this command line are ignored in favour of the
        # stored ones.
        with open(OUT_FNAME, 'rb') as istr:
            stg = pickle.load(istr)
        # Start from the stored map's constant term (presumably the mode
        # found by the previous run -- TODO confirm).
        x0 = stg.tmap.constantTerm
    else:
        with open(DIST_FNAME,'rb') as in_stream:
            stg.target_distribution = pickle.load(in_stream)
        x0 = None

    dim = stg.target_distribution.dim
    stg.base_distribution = DIST.StandardNormalDistribution(dim)

    ################ Laplace approximation #############
    laplace_approx = TM.laplace_approximation(
        stg.target_distribution,
        x0=x0, tol=stg.TOL, ders=stg.DERS,
        fungrad=stg.FUNGRAD,
        hessact=stg.HESSACT, hessact_rnd_eps=stg.HESSACT_RND_EPS,
        hessact_ovsamp=stg.HESSACT_OVSAMP, hessact_pow_n=stg.HESSACT_POW_N)
    # Linear map built from the Gaussian Laplace approximation. It is used
    # below both to pull the target back (approximate base) and to push the
    # standard normal base forward (approximate target).
    stg.tmap = MAPS.LinearTransportMap.build_from_Gaussian( laplace_approx )

    stg.approx_base_distribution = DIST.PullBackTransportMapDistribution(
        stg.tmap, stg.target_distribution)
    stg.approx_target_distribution = DIST.PushForwardTransportMapDistribution(
        stg.tmap, stg.base_distribution)

    ################ Identity Laplace ##################
    # Shift-only linear map: constant term at the Laplace mean, identity
    # linear part.
    stg.laplace_id = MAPS.LinearTransportMap(laplace_approx.mu, np.eye(dim))

    # Store Laplace approximations. Serialize first: opening with 'wb'
    # truncates the file, so a pickling failure would otherwise destroy a
    # previously stored (possibly just reloaded) result.
    serialized = pickle.dumps(stg)
    with open(OUT_FNAME, 'wb') as out_stream:
        out_stream.write(serialized)
finally:
    if INTERACTIVE:
        from IPython import embed
        embed()
