#!/usr/bin/env python

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

from __future__ import print_function

import sys
import getopt
import os
import os.path
import dill as pickle
import time
import datetime
import logging
import numpy as np
import scipy.stats as stats
import TransportMaps as TM
import TransportMaps.CLI as TMCLI


# Make modules located in the current working directory importable.
# NOTE(review): presumably required so that dill can resolve user-defined
# modules referenced by the pickled input files -- confirm.
sys.path.append(os.getcwd())

# Data storage object: an instance of an empty anonymous class used as a
# plain attribute namespace. Solver options and results are attached to it
# as attributes and the whole object is dumped to the output file.
stg = type('', (), {})()

def usage():
    """Print the short command-line synopsis of the script to stdout.

    The synopsis is surrounded by one empty line before and after it.
    """
    synopsis_lines = (
        "",
        "Usage: tmap-max-likelihood [-h -f -I] --output=OUTPUT",
        "  (--dist=DIST / --log-lkl=LKL)",
        "  [--tol=TOL --ders=DERS --fungrad --log=LOG]",
        "",
    )
    print("\n".join(synopsis_lines))

def full_usage():
    """Print the complete usage message.

    For this script it is the same synopsis printed by ``usage()``.
    """
    usage()

def description():
    """Print the long help text: what the script does and its options.

    The lists of admissible values for ``--ders`` and ``--log`` are
    generated by ``TMCLI.print_avail_options`` from ``TMCLI.AVAIL_DERS``
    and ``TMCLI.AVAIL_LOGGING``; everything else is static text.
    """
    # Indentation passed to print_avail_options so that the generated
    # option lists line up with the description column of the help text.
    pad = '                          '
    docs_ders_str = \
        '  --ders=DERS             derivatives to be used in the optimization (default: 2)\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_DERS, pad)
    docs_log_str = \
        '  --log=LOG               log level (default=30). Uses package logging.\n' + \
        TMCLI.print_avail_options(TMCLI.AVAIL_LOGGING, pad)
    # The text below describes what this script does: it computes the
    # maximum likelihood estimate and stores it together with its inputs.
    docs_str = """DESCRIPTION
Given a file (--dist) storing the target Bayesian posterior distribution, or
a file (--log-lkl) storing the log-likelihood function, compute the
corresponding maximum likelihood estimate.
All files involved are stored and loaded using the python package dill.

OPTIONS
  --output=OUTPUT         output file containing the maximum likelihood estimate,
                          the log-likelihood function (and the target
                          distribution, if --dist is used) and all the
                          additional parameters used for the computation
  --dist=DIST             file containing the target Bayesian posterior distribution
                            of type BayesPosteriorDistribution
                            (exclusive with --log-lkl)
  --log-lkl=LKL           file containing the log-likelihood function
                            (exclusive with --dist)
  --tol=TOL               optimization tolerance (default: 1e-4)
""" + docs_ders_str + docs_log_str + \
"""  --fungrad              whether the distributions provide a method to compute
                          the log pdf and its gradient at the same time
  -f                      force overwrite of OUTPUT file
  -I                      enter interactive mode after finishing
  -h                      print this help
"""
    print(docs_str)

def full_doc():
    """Print the full help: the usage message followed by the description
    of the script and of its options."""
    full_usage()
    description()

##################### INPUT PARSING #####################
# Parse the command line into module-level settings. Values that describe
# the solver configuration are attached to ``stg`` so that they are saved
# with the results; the remaining ones are plain module-level flags.
argv = sys.argv[1:]
INTERACTIVE = False
# I/O
OUT_FNAME = None
DIST_FNAME = None
LKL_FNAME = None
FORCE = False
# Solver options
stg.TOL = 1e-4
stg.DERS = 2
stg.FUNGRAD = False
# Logging
LOGGING_LEVEL = 30 # Warnings
try:
    opts, args = getopt.getopt(argv,"hfI",["output=", "dist=", "log-lkl=",
                                         "tol=", "ders=", "fungrad",
                                         "log="])
except getopt.GetoptError as e:
    full_usage()
    print(e)
    # Stop here: ``opts`` is undefined when parsing fails, so falling
    # through to the loop below would raise a NameError instead of
    # reporting a clean command-line error. Exit status 3 matches the
    # other command-line errors of this script.
    sys.exit(3)
for opt, arg in opts:
    if opt == '-h':
        full_doc()
        sys.exit()
    # I/O
    elif opt == "--output":
        OUT_FNAME = arg
    elif opt == "--dist":
        DIST_FNAME = arg
    elif opt == "--log-lkl":
        LKL_FNAME = arg
    # Solver options
    elif opt == '--tol':
        stg.TOL = float(arg)
    elif opt == '--ders':
        stg.DERS = int(arg)
    elif opt == '--fungrad':
        stg.FUNGRAD = True
    # Logging
    elif opt == '--log':
        LOGGING_LEVEL = int(arg)
    # Force overwrite
    elif opt == '-f':
        FORCE = True
    # Interactive
    elif opt == "-I":
        INTERACTIVE = True
    else:
        raise RuntimeError("Input option %s not recognized." % opt)

def tstamp_print(msg, *args, **kwargs):
    """Print ``msg`` prefixed with the current local date and time.

    The prefix has the form ``YYYY-MM-DD HH:MM:SS`` and is separated from
    ``msg`` by a single space. Any additional positional and keyword
    arguments are forwarded unchanged to ``print``.
    """
    now = datetime.datetime.fromtimestamp(time.time())
    stamped_msg = " ".join((now.strftime('%Y-%m-%d %H:%M:%S'), msg))
    print(stamped_msg, *args, **kwargs)

# Exactly one of --dist and --log-lkl must be given: reject both the case
# where neither is set and the case where both are set.
if (DIST_FNAME is None) == (LKL_FNAME is None):
    full_usage()
    tstamp_print("ERROR: Either --dist or --log-lkl must be specified")
    sys.exit(3)

# The output file is mandatory.
if OUT_FNAME is None:
    full_usage()
    tstamp_print("ERROR: Option --output must be specified")
    sys.exit(3)

# Unless -f was given, ask before overwriting an existing output file and
# keep asking until one of the recognized answers is entered.
if not FORCE and os.path.exists(OUT_FNAME):
    overwrite_prompt = "The file %s already exists. " % OUT_FNAME + \
                       "Do you want to overwrite? [y/N/q] "
    answer = ''
    while answer not in ('y', 'Y', 'n', 'N', 'q'):
        # Python 2 needs raw_input to read the answer as a plain string.
        if sys.version_info[0] == 3:
            answer = input(overwrite_prompt)
        else:
            answer = raw_input(overwrite_prompt)
    if answer in ('n', 'N', 'q'):
        tstamp_print("Terminating.")
        sys.exit(0)

try:
    TM.setLogLevel(LOGGING_LEVEL)

    # Load the log-likelihood, either directly from the --log-lkl file or
    # as the ``logL`` attribute of the distribution stored in --dist.
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    if DIST_FNAME is None:
        if LKL_FNAME is not None:
            with open(LKL_FNAME, 'rb') as lkl_file:
                stg.log_likelihood = pickle.load(lkl_file)
    else:
        with open(DIST_FNAME, 'rb') as dist_file:
            stg.target_distribution = pickle.load(dist_file)
        stg.log_likelihood = stg.target_distribution.logL

    # Maximum likelihood, computed with the solver options from the
    # command line
    stg.max_likelihood = TM.maximum_likelihood(
        stg.log_likelihood,
        tol=stg.TOL,
        ders=stg.DERS,
        fungrad=stg.FUNGRAD,
    )

    # Store the whole storage object (inputs, options and result)
    with open(OUT_FNAME, 'wb') as out_file:
        pickle.dump(stg, out_file)

finally:
    # Drop into an IPython shell when -I was given, even if the
    # computation above raised an exception.
    if INTERACTIVE:
        from IPython import embed
        embed()