#!/usr/bin/env python

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

from __future__ import print_function

import sys
import getopt
import os
import os.path
import shutil
import time
import datetime
import logging
import dill as pickle
import numpy as np
import scipy.stats as stats
import TransportMaps as TM
import TransportMaps.CLI as TMCLI
import TransportMaps.Maps as MAPS
import TransportMaps.Distributions as DIST
import TransportMaps.Diagnostics as DIAG
import TransportMaps.XML as TMXML
import TransportMaps.Algorithms.Adaptivity as ALGADPT

import numpy.random as npr
# Fixed seed for numpy's global random generator, so that runs are
# deterministic.
npr.seed(seed=0)
# Make modules located in the current working directory importable.
sys.path += [os.getcwd()]

class DataStorageObject(object):
    """Plain attribute container gathering the settings and results of a run."""


# NOTE: ``stg`` is bound to the class object itself, not to an instance;
# every setting and result is stored as a class attribute.
stg = DataStorageObject

# Python 2 compatibility: ``input`` must return the raw string typed by the
# user, as Python 3's ``input`` does.
if sys.version_info < (3,):
    input = raw_input

def usage():
    """Print the synopsis of the command line of tmap-tm.

    Only option names and their placeholders are listed; ``description()``
    explains the meaning and default value of every option. The option
    names listed here must match the long options accepted by the parser.
    """
    usage_str = """
Usage: tmap-tm [-h -I] 
  --dist=DIST --output=OUTPUT [--base-dist=BASE_DIST]
  (--mtype=MTYPE --span=SPAN --btype=BTYPE --order=ORDER --sparsity=SPARSITY)
    / (--map-descr=MAP_DESCR)
  --qtype=QTYPE --qnum=QNUM
  [--tol=TOL --maxit=MAXIT --reg=REG --ders=DERS --fungrad --hessact]
  [--validator=VNAME --val-eps=EPS --val-cost-fun=CFUN
   --val-max-cost=LIM --val-max-nsamps=LIM --val-stop-on-fcast]
  [--val-saa-eps-rel=EPS --val-saa-upper-mult=UMUL --val-saa-lower-n=LOWN 
   --val-saa-alpha=ALPHA --val-saa-lmb-def=LDEF --val-saa-lmb-max=LMAX]
  [--adapt=ADAPT --adapt-tol=TOL --adapt-verbosity=VAL]
  [--adapt-regr=REGR --adapt-regr-reg=REG
   --adapt-regr-tol=TOL --adapt-regr-maxit=MAXIT]
  [--adapt-fv-maxit=MAXIT --adapt-fv-prune-trunc=EPS
   --adapt-fv-avar-trunc=EPS --adapt-fv-coeff-trunc=EPS
   --adapt-fv-ls-maxit=MAXIT --adapt-fv-ls-delta=DEL]
  [--laplace-pull --map-pull=MAP]
  [--overwrite --reload --log=LOG --nprocs=NPROCS --batch=BATCH]
"""
    print(usage_str)

def description():
    """Print the detailed description of every command-line option.

    The admissible values of the enumerated options are formatted by
    ``TMCLI.print_avail_options``, so the help text lists the same values
    that the parser checks against. The defaults quoted here must match
    the defaults assigned before parsing the command line.
    """
    # Indentation aligning the lists of admissible values with the
    # description column of the option table below.
    avail_indent = '                          '

    def option_doc(header, avail_options):
        # One option line followed by the list of its admissible values.
        return header + '\n' + \
            TMCLI.print_avail_options(avail_options, avail_indent)

    docs_monotone_str = option_doc(
        '  --mtype=MTYPE           monotone format for the transport',
        TMCLI.AVAIL_MONOTONE)
    docs_span_str = option_doc(
        '  --span=SPAN             span type for all the components of the map',
        TMCLI.AVAIL_SPAN)
    docs_btype_str = option_doc(
        '  --btype=BTYPE           basis types for all the components of the map',
        TMCLI.AVAIL_BTYPE)
    docs_sparsity_str = option_doc(
        '  --sparsity=SPARSITY     sparsity pattern (default: tri) ',
        TMCLI.AVAIL_SPARSITY)
    docs_qtype_str = option_doc(
        '  --qtype=QTYPE           quadrature type for the discretization of ' +
        'the KL-divergence',
        TMCLI.AVAIL_QTYPE)
    docs_ders_str = option_doc(
        '  --ders=DERS             derivatives to be used in the optimization ' +
        '(default: 2)',
        TMCLI.AVAIL_DERS)
    docs_validator_str = option_doc(
        '  --validator=VNAME       validator to be used (default: none)',
        TMCLI.AVAIL_VALIDATOR)
    docs_cost_function_str = option_doc(
        '  --val-cost-fun=CFUN     cost function (default: tot-time)',
        TMCLI.AVAIL_COST_FUNCTION)
    docs_adaptivity_str = option_doc(
        '  --adapt=ADAPT           adaptivity algorithm for map construction ' +
        '(default: none)',
        TMCLI.AVAIL_ADAPTIVITY)
    docs_regression_adaptivity_str = option_doc(
        '  --adapt-regr=REGR       regression algorithm to be used within ' +
        'adaptivity (default: none)',
        TMCLI.AVAIL_REGRESSION_ADAPTIVITY)
    docs_log_str = option_doc(
        '  --log=LOG               log level (default=30). Uses package logging.',
        TMCLI.AVAIL_LOGGING)

    docs_str = """DESCRIPTION
Given a file (--dist) storing the target distribution, produce the transport map that
pushes forward the base distribution (default: standard normal) to the target distribution.
All files involved are stored and loaded using the python package dill.

OPTIONS - input/output:
  --dist=DIST             path to the file containing the target distribution 
  --output=OUTPUT         path to the output file containing the transport map,  
                          the base distribution, the target distribution and all 
                          the additional parameters used for the construction 
  --base-dist=BASE_DIST   path to the file containing the base distribution
                          (default: a standard normal of suitable dimension)
OPTIONS - map description (using default maps):
""" + docs_monotone_str + docs_span_str + docs_btype_str + \
"""  --order=ORDER           order of the transport map
""" + docs_sparsity_str + \
"""OPTIONS - map description (manual):
  --map-descr=MAP_DESCR   XML file containing the skeleton of the transport map
OPTIONS - solver:
""" + docs_qtype_str + \
"""  --qnum=QNUM             quadrature level
  --tol=TOL               kl minimization tolerance (default: 1e-4)
  --maxit=MAXIT           maximum number of iterations for kl minimization
                          (default: 100)
  --reg=REG               a float L2 regularization parameter
                          (default: no regularization)
""" + docs_ders_str + \
"""  --fungrad               whether the distributions provide a method to compute
                          the log pdf and its gradient at the same time
  --hessact               whether to use the action of the Hessian
""" + docs_validator_str + \
"""  --val-eps=EPS           target tolerance for solution of the stochastic program
                          (default: 1e-2)
""" + docs_cost_function_str + \
"""  --val-max-cost=LIM      total cost limit (default: 1000)
  --val-max-nsamps=LIM    maximum number of samples to use in the approximation of
                          the expectations (default: infinity)
  --val-stop-on-fcast     whether to stop on a forecast to exceed the cost limit
                          (by default it stops only after exceeding the cost limit)
  --val-saa-eps-rel=EPS   [SAA] relative error to be used (--val-eps is absolute)
                          (default: 1e-2)
  --val-saa-upper-mult=VAL [SAA] upper multiplier (default: 10)
  --val-saa-lower-n=VAL   [SAA] number of samples for lower bound (default: 2)
  --val-saa-alpha=VAL     [SAA] quantile (default: 0.05)
  --val-saa-lmb-def=VAL   [SAA] default sample multiplier (default: 2)
  --val-saa-lmb-max=VAL   [SAA] maximum sample multiplier (default: 10)
""" + docs_adaptivity_str + \
"""  --adapt-tol=TOL         target variance diagnostic tolerance (default: 5e-2)
  --adapt-verbosity=VAL   This regulates the amount of information printed by the logger.
                          Values are >0 with higher values corresponding to higher verbosity.
                          Default is 0.
""" + docs_regression_adaptivity_str + \
"""  --adapt-regr-reg=REG    regularization to be used in regression
                          (default: no regularization)
  --adapt-regr-tol=TOL    regression tolerance (default: 1e-4)
  --adapt-regr-maxit=MAXIT maximum number of iterations in regression (default: 100)
  --adapt-fv-maxit=MAXIT [first variation] maximum number of iterations (default: 20)
  --adapt-fv-prune-trunc=EPS [first variation] prune truncation parameter (default: .1)
  --adapt-fv-avar-trunc=EPS [first variation] active variables trunc parameter (default: .2)
  --adapt-fv-coeff-trunc=EPS [first variation] coefficient truncation parameter
                          (default: .01)
  --adapt-fv-ls-maxit=MAXIT [first variation] maximum number of line search
                          iterations (default: 20)
  --adapt-fv-ls-delta=DEL [first variation] initial step size for line search
                          (default: 2)
  --laplace-pull          whether to precondition pulling back the target through
                          its Laplace approximation
  --map-pull=MAP          path to file containing a map through which to pullback 
                          the target (this is done before pulling back through
                          the Laplace, if --laplace-pull is provided).
                          The file may contain just the map or may be the output
                          of any other map construction scripts (tmap-tm, ...)
  --overwrite             overwrite file if it exists
  --reload                reload file if it exists
""" + docs_log_str + \
"""  --nprocs=NPROCS         number of processors to be used (default=1)
  --batch=BATCH           batch size (approximate maximum number of floats in memory)
OPTIONS - other:
  -I                      enter interactive mode after finishing
  -h                      print this help
"""
    print(docs_str)

def full_usage():
    """Print the command-line synopsis (same output as ``usage``)."""
    return usage()

def full_doc():
    """Print the command-line synopsis followed by the option descriptions."""
    for print_section in (full_usage, description):
        print_section()

def tstamp_print(msg, *args, **kwargs):
    """Print ``msg`` prefixed with the current local date and time.

    Additional positional and keyword arguments are forwarded to ``print``.
    Defined before the command line is parsed because the parser reports
    invalid option arguments through it.
    """
    tstamp = datetime.datetime.fromtimestamp(
        time.time()
    ).strftime('%Y-%m-%d %H:%M:%S')
    print(tstamp + " " + msg, *args, **kwargs)

##################### INPUT PARSING #####################
argv = sys.argv[1:]
INTERACTIVE = False
# I/O
DIST_FNAME = None
OUT_FNAME = None
BASE_DIST_FNAME = None
# Map type
MONOTONE = None
SPAN = None
BTYPE = None
ORDER = None
SPARSITY = 'tri'
MAP_DESCR = None
# Validation
VALIDATOR = 'none'
VAL_EPS = 1e-2
VAL_COST_FUN = 'tot-time'
VAL_MAX_COST = 1000
VAL_MAX_NSAMPS = np.inf
VAL_STOP_ON_FCAST = False
# Sample average approximation validator
VAL_SAA_EPS_REL = 1e-2
VAL_SAA_UPPER_MULT = 10
VAL_SAA_LOWER_N = 2
VAL_SAA_ALPHA = 0.05
VAL_SAA_LMB_DEF = 2
VAL_SAA_LMB_MAX = 10
# Adaptivity
ADAPT = 'none'
ADAPT_TOL = 5e-2
ADAPT_VERB = 0
ADAPT_REGR = 'none'
ADAPT_REGR_REG = None
ADAPT_REGR_TOL = 1e-4
ADAPT_REGR_MAX_IT = 100
ADAPT_FV_MAX_IT = 20
ADAPT_FV_PRUNE_TRUNC = .1
ADAPT_FV_AVAR_TRUNC = .2
ADAPT_FV_COEFF_TRUNC = 1e-2
ADAPT_FV_LS_MAXIT = 20
ADAPT_FV_LS_DELTA = 2.
# Quadrature type
stg.QTYPE = None
stg.QNUM = None
# Solver options
stg.TOL = 1e-4
stg.MAXIT = 100
stg.REG = None
stg.DERS = 2
stg.FUNGRAD = False
stg.HESSACT = False
# Pre-pull Laplace
stg.LAPLACE_PULL = False
stg.MAP_PULL = None
# Overwriting/reloading
OVERWRITE = False
RELOAD = False
# Logging
LOGGING_LEVEL = 30 # Warnings
# Parallelization
NPROCS = 1
BATCH_SIZE = int(1e9)
try:
    opts, args = getopt.getopt(
        argv, "hI",
        [
            # I/O
            "dist=", "output=", "base-dist=",
            # Map type
            "mtype=", "span=", "btype=", "order=", "sparsity=",
            "map-descr=",
            # Quadrature type
            "qtype=", "qnum=",
            # Solver options
            "tol=", "maxit=", "reg=", "ders=", "fungrad", "hessact",
            # Validation
            "validator=", "val-eps=",
            "val-cost-fun=", "val-max-cost=",
            "val-max-nsamps=", "val-stop-on-fcast",
            # Sample average approximation validator
            "val-saa-eps-rel=", "val-saa-upper-mult=", "val-saa-lower-n=",
            "val-saa-alpha=", "val-saa-lmb-def=", "val-saa-lmb-max=",
            # Adaptivity options
            "adapt=", "adapt-tol=", "adapt-verbosity=",
            "adapt-regr=", "adapt-regr-reg=", "adapt-regr-tol=", "adapt-regr-maxit=",
            "adapt-fv-maxit=", "adapt-fv-prune-trunc=",
            "adapt-fv-avar-trunc=", "adapt-fv-coeff-trunc=",
            "adapt-fv-ls-maxit=", 'adapt-fv-ls-delta=',
            # Whether to pre-pull through Laplace
            "laplace-pull", "map-pull=",
            # Overwriting and reloading
            "overwrite", "reload",
            # Logging
            "log=",
            # Parallelization and batching option
            "nprocs=", "batch="
        ])
except getopt.GetoptError:
    full_usage()
    raise
for opt, arg in opts:
    if opt == '-h':
        full_doc()
        sys.exit()

    # Interactive
    elif opt == '-I':
        INTERACTIVE = True

    # I/O
    elif opt == '--dist':
        DIST_FNAME = arg
    elif opt == '--output':
        OUT_FNAME = arg
    elif opt == '--base-dist':
        BASE_DIST_FNAME = arg

    # Map type
    elif opt == '--mtype':
        MONOTONE = arg
    elif opt == '--span':
        SPAN = arg
    elif opt == '--btype':
        BTYPE = arg
    elif opt == '--order':
        ORDER = int(arg)
    elif opt == '--sparsity':
        if arg not in TMCLI.AVAIL_SPARSITY:
            usage()
            tstamp_print("ERROR: Argument %s for --sparsity not recognized" % arg)
            sys.exit(3)
        SPARSITY = arg
    elif opt == '--map-descr':
        MAP_DESCR = arg

    # Quadrature type
    elif opt == '--qtype':
        stg.QTYPE = int(arg)
    elif opt == '--qnum':
        # Comma separated list of levels (reduced to a scalar later for
        # quadrature types < 3)
        stg.QNUM = [int(q) for q in arg.split(',')]

    # Solver options
    elif opt == '--tol':
        stg.TOL = float(arg)
    elif opt == '--maxit':
        stg.MAXIT = int(arg)
    elif opt == '--reg':
        # One L2 regularization per comma separated value (one per order
        # with the sequential adaptivity algorithms)
        stg.REG = [ {'type': 'L2', 'alpha': float(q)}
                    for q in arg.split(',') ]
    elif opt == '--ders':
        stg.DERS = int(arg)
    elif opt == '--fungrad':
        stg.FUNGRAD = True
    elif opt == '--hessact':
        stg.HESSACT = True

    # Validation options
    elif opt == '--validator':
        if arg not in TMCLI.AVAIL_VALIDATOR:
            usage()
            tstamp_print("ERROR: Argument %s for --validator not recognized" % arg)
            sys.exit(3)
        VALIDATOR = arg
    elif opt == '--val-eps':
        VAL_EPS = float(arg)
    elif opt == '--val-cost-fun':
        if arg not in TMCLI.AVAIL_COST_FUNCTION:
            usage()
            tstamp_print("ERROR: Argument %s for --val-cost-fun not recognized" % arg)
            sys.exit(3)
        VAL_COST_FUN = arg
    elif opt == '--val-max-cost':
        VAL_MAX_COST = float(arg)
    elif opt == '--val-max-nsamps':
        VAL_MAX_NSAMPS = int(arg)
    elif opt == '--val-stop-on-fcast':
        VAL_STOP_ON_FCAST = True
    # Sample average approximation options
    elif opt == '--val-saa-eps-rel':
        VAL_SAA_EPS_REL = float(arg)
    elif opt == '--val-saa-upper-mult':
        VAL_SAA_UPPER_MULT = float(arg)
    elif opt == '--val-saa-lower-n':
        VAL_SAA_LOWER_N = int(arg)
    elif opt == '--val-saa-alpha':
        VAL_SAA_ALPHA = float(arg)
    elif opt == '--val-saa-lmb-def':
        VAL_SAA_LMB_DEF = float(arg)
    elif opt == '--val-saa-lmb-max':
        VAL_SAA_LMB_MAX = float(arg)

    # Adaptivity options
    elif opt == '--adapt':
        ADAPT = arg
    elif opt == '--adapt-tol':
        ADAPT_TOL = float(arg)
    elif opt == '--adapt-verbosity':
        ADAPT_VERB = int(arg)
    elif opt == '--adapt-regr':
        if arg not in TMCLI.AVAIL_REGRESSION_ADAPTIVITY:
            usage()
            tstamp_print("ERROR: Argument %s for --adapt-regr not recognized" % arg)
            sys.exit(3)
        ADAPT_REGR = arg
    elif opt == '--adapt-regr-reg':
        ADAPT_REGR_REG = [ {'type': 'L2', 'alpha': float(q)}
                           for q in arg.split(',') ]
    elif opt == '--adapt-regr-tol':
        ADAPT_REGR_TOL = float(arg)
    elif opt == '--adapt-regr-maxit':
        ADAPT_REGR_MAX_IT = int(arg)
    elif opt == '--adapt-fv-maxit':
        ADAPT_FV_MAX_IT = int(arg)
    elif opt == '--adapt-fv-prune-trunc':
        ADAPT_FV_PRUNE_TRUNC = float(arg)
    elif opt == '--adapt-fv-avar-trunc':
        ADAPT_FV_AVAR_TRUNC = float(arg)
    elif opt == '--adapt-fv-coeff-trunc':
        ADAPT_FV_COEFF_TRUNC = float(arg)
    elif opt == '--adapt-fv-ls-maxit':
        ADAPT_FV_LS_MAXIT = int(arg)
    elif opt == '--adapt-fv-ls-delta':
        ADAPT_FV_LS_DELTA = float(arg)

    # Pre-pull
    elif opt == '--laplace-pull':
        stg.LAPLACE_PULL = True
    elif opt == '--map-pull':
        stg.MAP_PULL = arg

    # Overwriting/reloading
    elif opt == '--overwrite':
        OVERWRITE = True
    elif opt == '--reload':
        RELOAD = True

    # Logging
    elif opt == '--log':
        LOGGING_LEVEL = int(arg)

    # Parallelization and batching
    elif opt == '--nprocs':
        NPROCS = int(arg)
    elif opt == '--batch':
        # Accept scientific notation, e.g. --batch=1e7
        BATCH_SIZE = int(float(arg))

    else:
        raise ValueError("Option %s not recognized" % opt)
    
# Check for required arguments. On error the usage is printed and the
# script exits with status 3.
if OVERWRITE and RELOAD:
    usage()
    tstamp_print("ERROR: options --overwrite and --reload are mutually exclusive")
    sys.exit(3)
if RELOAD and OUT_FNAME is None:
    # --reload loads the state stored in the --output file
    usage()
    tstamp_print("ERROR: Option --output must be specified")
    sys.exit(3)
if not RELOAD:
    if None in [DIST_FNAME, OUT_FNAME]:
        usage()
        tstamp_print("ERROR: Options --dist and --output must be specified")
        sys.exit(3)
    if None in [stg.QTYPE, stg.QNUM]:
        usage()
        tstamp_print("ERROR: Options --qtype and --qnum must be specified")
        sys.exit(3)
    # For quadrature types < 3 only the first --qnum value is kept, as a scalar
    if stg.QTYPE < 3:
        stg.QNUM = stg.QNUM[0]
    # The map is described either entirely by --mtype/--span/--btype/--order
    # or entirely by --map-descr: never partially, never both.
    map_descr_list = [MONOTONE, SPAN, BTYPE, ORDER]
    default_map_incomplete = None in map_descr_list
    default_map_partial = any(s is not None for s in map_descr_list)
    if (MAP_DESCR is None and default_map_incomplete) or \
       (MAP_DESCR is not None and default_map_partial):
        usage()
        tstamp_print("ERROR: Either options --mtype, --span, --btype, " + \
              "--order are specified or option --map-descr is specified")
        sys.exit(3)
    if ADAPT not in TMCLI.AVAIL_ADAPTIVITY:
        usage()
        tstamp_print("ERROR: adaptivity algorithm not recognized")
        sys.exit(3)

def store(tm, stg, fname):
    """Compose the final map and pickle the storage object ``stg`` to ``fname``.

    It is installed as the builder callback and is also called once the
    construction is over to store the final result.

    Parameters
    ----------
    tm :
        transport map produced by the builder; it is composed after the
        preconditioning maps (Laplace and/or --map-pull).
    stg :
        storage object. It must provide ``target_distribution``,
        ``base_distribution``, ``LAPLACE_PULL``, ``lapmap``, ``MAP_PULL``
        and ``map_pull``; ``precond_map``, ``tmap``,
        ``approx_base_distribution`` and ``approx_target_distribution``
        are (re)set here.
    fname : str
        output path. An existing file is first copied to ``fname + '.bak'``;
        the copy is removed only after the new file has been written, so an
        interrupted write leaves the previous result recoverable.
    """
    precond_map = [ MAPS.IdentityTransportMap(stg.target_distribution.dim) ]
    # NOTE(review): the target is pulled back first through ``map_pull`` and
    # then through ``lapmap``; confirm that ListCompositeMap composes this
    # list in the order matching that sequence.
    if stg.LAPLACE_PULL:
        precond_map.append( stg.lapmap )
    if stg.MAP_PULL is not None:
        precond_map.append( stg.map_pull )
    stg.precond_map = MAPS.ListCompositeMap( precond_map )

    stg.tmap = MAPS.CompositeMap(stg.precond_map, tm)

    stg.approx_base_distribution = DIST.PullBackTransportMapDistribution(
        stg.tmap, stg.target_distribution)
    stg.approx_target_distribution = DIST.PushForwardTransportMapDistribution(
        stg.tmap, stg.base_distribution)

    bak_fname = fname + '.bak'
    # Backup copy
    if os.path.exists(fname):
        shutil.copyfile(fname, bak_fname)
    # Store
    with open(fname, 'wb') as out_stream:
        pickle.dump(stg, out_stream)
    # Remove backup (best effort: it does not exist when fname was new).
    # Only OS errors are ignored, so KeyboardInterrupt/SystemExit propagate.
    try:
        os.remove(bak_fname)
    except OSError:
        pass
        
def _ask_until_valid(question, answers):
    """Prompt with ``question`` until the reply is one of ``answers``."""
    reply = ''
    while reply not in answers:
        reply = input(question)
    return reply

# The output file already exists and neither --overwrite nor --reload was
# given: let the user decide whether to overwrite, reload or quit.
if not (OVERWRITE or RELOAD) and os.path.exists(OUT_FNAME):
    action = _ask_until_valid(
        "The file %s already exists. Do you want to overwrite (o), "
        "reload (r) or quit (q)? [o/r/q] " % OUT_FNAME,
        ['o', 'r', 'q'])
    if action == 'r':
        RELOAD = True
    elif action == 'q' or _ask_until_valid(
            "Please, confirm that you want overwrite. [y/n]",
            ['y', 'n']) == 'n':
        # Quit, or overwrite not confirmed
        tstamp_print("Terminating.")
        sys.exit(0)
        
# Load (or reload) the problem, set up the builder, solve and store the
# result. With -I an IPython shell is opened at the end, even on failure.
try:
    ##################### DATA LOADING #####################
    logging.basicConfig(level=LOGGING_LEVEL)
    TM.setLogLevel(LOGGING_LEVEL)

    # Start mpi pool
    mpi_pool = None
    if NPROCS > 1:
        mpi_pool = TM.get_mpi_pool()
        mpi_pool.start(NPROCS)

    # The pool is stopped on every exit path (including errors and
    # sys.exit() raised while setting up the builder), and before the final
    # result is stored.
    try:
        if RELOAD:
            with open(OUT_FNAME, 'rb') as istr:
                stg = pickle.load(istr)

            # Set callback in builder
            stg.builder.callback = store
            stg.builder.callback_kwargs = {'stg': stg, 'fname': OUT_FNAME}

            # Set mpi_pool in the builder
            stg.builder.set_mpi_pool(mpi_pool)

        else:
            if ADAPT not in ('none', 'fv', 'sequential', 'tol-sequential'):
                raise NotImplementedError(
                    "--adapt=%s not supported by this script" % ADAPT)

            # Load target distribution
            with open(DIST_FNAME, 'rb') as istr:
                stg.target_distribution = pickle.load(istr)
            dim = stg.target_distribution.dim

            # Load base distribution
            if BASE_DIST_FNAME is None:
                stg.base_distribution = DIST.StandardNormalDistribution(dim)
            else:
                with open(BASE_DIST_FNAME, 'rb') as istr:
                    stg.base_distribution = pickle.load(istr)

            # Instantiate Transport Map
            if MAP_DESCR is not None:
                # The sequential algorithms need one map per order, while
                # the XML description provides a single map.
                if ADAPT in ('sequential', 'tol-sequential'):
                    usage()
                    tstamp_print(
                        "ERROR: option --adapt=%s is not available " % ADAPT +
                        "with --map-descr (yet..)")
                    sys.exit(3)
                tm_approx = TMXML.load_xml(MAP_DESCR)
            else:
                if SPARSITY not in ('tri', 'diag'):
                    raise NotImplementedError(
                        "--sparsity=%s not supported by this script" % SPARSITY)
                if MONOTONE == 'linspan':
                    if SPARSITY == 'tri':
                        map_constructor = \
                            TM.Default_IsotropicMonotonicLinearSpanTriangularTransportMap
                    else:
                        map_constructor = \
                            TM.Default_IsotropicMonotonicLinearSpanDiagonalTransportMap
                elif MONOTONE == 'intexp':
                    if SPARSITY == 'tri':
                        map_constructor = \
                            TM.Default_IsotropicIntegratedExponentialTriangularTransportMap
                    else:
                        map_constructor = \
                            TM.Default_IsotropicIntegratedExponentialDiagonalTransportMap
                elif MONOTONE == 'intsq':
                    if SPARSITY == 'tri':
                        map_constructor = \
                            TM.Default_IsotropicIntegratedSquaredTriangularTransportMap
                    else:
                        map_constructor = \
                            TM.Default_IsotropicIntegratedSquaredDiagonalTransportMap
                else:
                    raise ValueError("Monotone type not recognized (linspan|intexp|intsq)")
                if ADAPT in ('none', 'fv'):
                    tm_approx = map_constructor(dim, ORDER, span=SPAN, btype=BTYPE)
                    logging.info("Number coefficients: %d", tm_approx.n_coeffs)
                else:
                    # Sequential algorithms: one map for each order 1..ORDER
                    tm_approx = [map_constructor(dim, o, span=SPAN, btype=BTYPE)
                                 for o in range(1, ORDER + 1)]
                    n_coeffs = sum(tm.n_coeffs for tm in tm_approx)
                    logging.info("Number coefficients: %d", n_coeffs)

            # Set up validator
            if VALIDATOR == 'none':
                validator = None
            elif VALIDATOR == 'saa':
                if VAL_COST_FUN == 'tot-time':
                    cost_fun = TM.total_time_cost_function
                else:
                    raise NotImplementedError(
                        "--val-cost-fun=%s not supported by this script" % VAL_COST_FUN)
                validator = DIAG.SampleAverageApproximationKLMinimizationValidator(
                    VAL_EPS, VAL_SAA_EPS_REL, cost_fun, VAL_MAX_COST,
                    max_nsamps=VAL_MAX_NSAMPS, stop_on_fcast=VAL_STOP_ON_FCAST,
                    upper_mult=VAL_SAA_UPPER_MULT, lower_n=VAL_SAA_LOWER_N,
                    alpha=VAL_SAA_ALPHA, lmb_def=VAL_SAA_LMB_DEF,
                    lmb_max=VAL_SAA_LMB_MAX)
            else:
                raise NotImplementedError(
                    "--validator=%s not supported by this script" % VALIDATOR)

            tar = stg.target_distribution

            # Map pullback
            stg.map_pull = None
            if stg.MAP_PULL is not None:
                with open(stg.MAP_PULL, 'rb') as istr:
                    stg.map_pull = pickle.load(istr)
                # The file may be the output of another script: extract its map
                if not issubclass(type(stg.map_pull), MAPS.Map):
                    stg.map_pull = stg.map_pull.tmap
                tar = DIST.PullBackTransportMapDistribution(stg.map_pull, tar)

            # Laplace pullback
            stg.lapmap = None
            if stg.LAPLACE_PULL:
                laplace_approx = TM.laplace_approximation( tar )
                stg.lapmap = MAPS.LinearTransportMap.build_from_Gaussian( laplace_approx )
                tar = DIST.PullBackTransportMapDistribution( stg.lapmap, tar )

            if ADAPT in ('none', 'fv'):
                solve_params = {
                    'qtype': stg.QTYPE,
                    'qparams': stg.QNUM,
                    'tol': stg.TOL,
                    'maxit': stg.MAXIT,
                    'regularization': None if stg.REG is None else stg.REG[0],
                    'ders': stg.DERS,
                    'fungrad': stg.FUNGRAD,
                    'hessact': stg.HESSACT,
                    'batch_size': BATCH_SIZE,
                    'mpi_pool': mpi_pool
                }
            else:
                # A single --reg value is used for every order
                if stg.REG is not None and len(stg.REG) == 1:
                    stg.REG = stg.REG * ORDER
                solve_params = {'solve_params_list': []}
                for i in range(len(tm_approx)):
                    solve_params['solve_params_list'].append( {
                        'qtype': stg.QTYPE, 'qparams': stg.QNUM,
                        'tol': stg.TOL, 'maxit': stg.MAXIT,
                        'regularization': None if stg.REG is None else stg.REG[i],
                        'ders': stg.DERS, 'fungrad': stg.FUNGRAD, 'hessact': stg.HESSACT,
                        'batch_size': BATCH_SIZE, 'mpi_pool': mpi_pool
                    } )
                if ADAPT == 'tol-sequential':
                    solve_params['tol'] = ADAPT_TOL
                    # NOTE(review): no command line option selects the
                    # quadrature of the variance diagnostic, so the one of the
                    # KL-divergence minimization is reused. Confirm this is
                    # the intended monitor quadrature.
                    solve_params['var_diag_params'] = {
                        'qtype': stg.QTYPE, 'qparams': stg.QNUM,
                        'mpi_pool_tuple': (None, mpi_pool) }

            # Define callback (storage) arguments
            callback_kwargs = {'stg': stg, 'fname': OUT_FNAME}

            # Set up adaptivity algorithm
            if ADAPT == 'none':
                stg.builder = ALGADPT.KullbackLeiblerBuilder(
                    stg.base_distribution, tar,
                    tm_approx, solve_params, validator,
                    callback=store, callback_kwargs=callback_kwargs,
                    verbosity=ADAPT_VERB)
            elif ADAPT == 'sequential':
                stg.builder = ALGADPT.SequentialKullbackLeiblerBuilder(
                    stg.base_distribution, tar,
                    tm_approx, solve_params, validator=validator,
                    callback=store, callback_kwargs=callback_kwargs,
                    verbosity=ADAPT_VERB)
            elif ADAPT == 'tol-sequential':
                stg.builder = ALGADPT.ToleranceSequentialKullbackLeiblerBuilder(
                    stg.base_distribution, tar,
                    tm_approx, solve_params, validator=validator, tol=ADAPT_TOL,
                    callback=store, callback_kwargs=callback_kwargs,
                    verbosity=ADAPT_VERB)
            else:  # 'fv': the only algorithm using a regression builder
                regression_params = {
                    'regularization': ADAPT_REGR_REG,
                    'tol': ADAPT_REGR_TOL,
                    'maxit': ADAPT_REGR_MAX_IT }
                if ADAPT_REGR != 'none':
                    raise NotImplementedError(
                        "--adapt-regr=%s not supported by this script" % ADAPT_REGR)
                regressor = ALGADPT.L2RegressionBuilder(regression_params)
                line_search_params = {'maxiter': ADAPT_FV_LS_MAXIT,
                                      'delta': ADAPT_FV_LS_DELTA}
                stg.builder = ALGADPT.FirstVariationKullbackLeiblerBuilder(
                    stg.base_distribution, tar,
                    tm_approx, validator=validator,
                    eps_bull=ADAPT_TOL,
                    regression_builder=regressor,
                    solve_params=solve_params,
                    line_search_params=line_search_params,
                    max_it=ADAPT_FV_MAX_IT,
                    # The first-variation Hessian is disabled when only the
                    # action of the Hessian is requested
                    use_fv_hess=stg.DERS > 1 and not stg.HESSACT,
                    prune_trunc=ADAPT_FV_PRUNE_TRUNC,
                    avar_trunc=ADAPT_FV_AVAR_TRUNC, coeff_trunc=ADAPT_FV_COEFF_TRUNC,
                    callback=store, callback_kwargs=callback_kwargs,
                    verbosity=ADAPT_VERB)

        (tm_approx, log) = stg.builder.solve(reloading=RELOAD)
    finally:
        if mpi_pool is not None:
            mpi_pool.stop()

    store(tm_approx, stg, OUT_FNAME)

finally:
    if INTERACTIVE:
        from IPython import embed
        embed()
