#!/usr/bin/env python

#
# This file is part of TransportMaps.
#
# TransportMaps is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# TransportMaps is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with TransportMaps.  If not, see <http://www.gnu.org/licenses/>.
#
# Transport Maps Library
# Copyright (C) 2015-2018 Massachusetts Institute of Technology
# Uncertainty Quantification group
# Department of Aeronautics and Astronautics
#
# Author: Transport Map Team
# Website: transportmaps.mit.edu
# Support: transportmaps.mit.edu/qa/
#

from __future__ import print_function

import sys
import getopt
import os
import os.path
import shutil
import time
import datetime
import logging
import dill as pickle
import numpy as np
import scipy.stats as stats
import TransportMaps as TM
import TransportMaps.CLI as CLI
import TransportMaps.Maps as MAPS
import TransportMaps.Distributions as DIST
import TransportMaps.Distributions.Decomposable as DECDIST
import TransportMaps.Algorithms.SequentialInference as ALGSI
import TransportMaps.Algorithms.Adaptivity as ALGADPT
import TransportMaps.XML as TMXML

# Make modules living in the current working directory importable.
# NOTE(review): presumably needed so that dill can unpickle DIST files whose
# objects reference user-defined modules in the CWD — confirm.
sys.path.append(os.getcwd())

# Make stdout unbuffered
class Unbuffered(object):
    """Proxy around a text stream that flushes after every write.

    Any attribute not defined on the proxy is looked up on the wrapped
    stream, so the object can stand in for the stream itself.
    """

    def __init__(self, stream):
        # Wrapped stream; all other attribute accesses are forwarded to it.
        self.stream = stream

    def write(self, data):
        """Write ``data`` to the wrapped stream and flush it immediately."""
        target = self.stream
        target.write(data)
        target.flush()

    def writelines(self, datas):
        """Write every item of ``datas`` and flush once afterwards."""
        target = self.stream
        target.writelines(datas)
        target.flush()

    def __getattr__(self, name):
        # Only invoked for attributes missing on the proxy itself.
        return getattr(self.stream, name)
# Replace stdout with the auto-flushing wrapper so that every print reaches
# the underlying stream immediately.
sys.stdout = Unbuffered(sys.stdout)

# Send mail on error/termination routine
def sendmail(to, sbj, txt):
    """Send a plain-text e-mail through the SMTP server on ``localhost``.

    Used to notify the address given with ``--email``.

    Args:
      to (str): recipient address.
      sbj (str): subject line.
      txt (str): message body.
    """
    import socket
    import smtplib
    from email.mime.text import MIMEText
    msg = MIMEText(txt)
    msg['subject'] = sbj
    # The sender is the bare host name of this machine.
    msg['from'] = socket.gethostname()
    msg['to'] = to
    server = smtplib.SMTP('localhost')
    try:
        server.sendmail(msg['from'], to, msg.as_string())
    finally:
        # Always release the connection, even when sending failed. If the
        # server already dropped the connection, QUIT itself fails; fall back
        # to closing the socket so the original error (if any) propagates.
        try:
            server.quit()
        except (smtplib.SMTPException, socket.error):
            server.close()


# Data storage object: an instance of an anonymous, empty class used as a
# mutable namespace for settings and state (quadrature, solver options,
# target/base distributions, integrator, ...). When an existing OUTPUT file is
# reloaded, this whole object is replaced by the unpickled copy.
# NOTE(review): changing how this class is created may affect unpickling of
# previously stored OUTPUT files — keep as-is unless that is verified.
stg = type('', (), {})()

def usage():
    """Print the short command line synopsis.

    Option names listed here must match those accepted by the getopt call
    (e.g. ``--reg``/``--hyper-reg``, not ``--with-reg``).
    """
    usage_str = """
Usage: tmap-sequential-tm [-h -f -v -I]
  --dist=DIST --output=OUTPUT
  (--mtype=MTYPE --span=SPAN --btype=BTYPE --order=ORDER)
    / (--map-0-descr=MAP_DESCR --map-descr=MAP_DESCR)
  --qtype=QTYPE --qnum=QNUM
  [--tol=TOL --maxit=MAXIT --reg=REG --ders=DERS --fungrad]
  [--adapt=ADAPT --adapt-tol=TOL]
  [--monitor-qtype=QTYPE --monitor-qnum=QNUM]
  [(--hyper-mtype=MTYPE --hyper-span=SPAN --hyper-btype=BTYPE --hyper-order=ORDER)
    / (--hyper-map-descr=MAP_DESCR)
   --hyper-qtype=QTYPE --hyper-qnum=QNUM
   --hyper-tol=TOL --hyper-maxit=MAXIT --hyper-reg=REG]
  [--hyper-adapt=ADAPT --hyper-adapt-tol=TOL]
  [--hyper-monitor-qtype=QTYPE --hyper-monitor-qnum=QNUM]
  [--continue-on-error --safe-mode=LAG --reload --log=LOG --nprocs=NPROCS
   --batch=BATCH --email=EMAIL]
"""
    print(usage_str)

def description():
    """Print the detailed description of every command line option.

    The lists of admissible values (map types, span, basis types, quadrature
    types, derivatives, adaptivity algorithms, logging levels) are generated
    from the tables in ``TransportMaps.CLI``.
    """
    # Column at which the option explanations start.
    indent = '                          '
    docs_monotone_str = \
        '  --mtype=MTYPE           monotone format for the transport\n' + \
        CLI.print_avail_options(CLI.AVAIL_MONOTONE, indent)
    docs_span_str = \
        '  --span=SPAN             span type for all the components of the map\n' + \
        CLI.print_avail_options(CLI.AVAIL_SPAN, indent)
    docs_btype_str = \
        '  --btype=BTYPE           basis types for all the components of the map\n' + \
        CLI.print_avail_options(CLI.AVAIL_BTYPE, indent)
    docs_qtype_str = \
        '  --qtype=QTYPE           quadrature type for the discretization of ' + \
        'the KL-divergence\n' + \
        CLI.print_avail_options(CLI.AVAIL_QTYPE, indent)
    docs_ders_str = \
        '  --ders=DERS             derivatives to be used in the optimization\n' + \
        CLI.print_avail_options(CLI.AVAIL_DERS, indent)
    docs_adaptivity_str = \
        '  --adapt=ADAPT           adaptivity algorithm for map construction\n' + \
        CLI.print_avail_options(CLI.AVAIL_ADAPTIVITY, indent)
    docs_monitor_qtype_str = \
        '  --monitor-qtype=QTYPE   quadrature type for the discretization of ' + \
        'the variance diagnostic\n' + \
        CLI.print_avail_options(CLI.AVAIL_QTYPE, indent)
    docs_hyper_monotone_str = \
        '  --hyper-mtype=MTYPE     monotone format for the transport\n' + \
        CLI.print_avail_options(CLI.AVAIL_MONOTONE, indent)
    docs_hyper_span_str = \
        '  --hyper-span=SPAN       span type for all the components of the map\n' + \
        CLI.print_avail_options(CLI.AVAIL_SPAN, indent)
    docs_hyper_btype_str = \
        '  --hyper-btype=BTYPE     basis types for all the components of the map\n' + \
        CLI.print_avail_options(CLI.AVAIL_BTYPE, indent)
    docs_hyper_qtype_str = \
        '  --hyper-qtype=QTYPE     quadrature type for the discretization of ' + \
        'the L2-norm\n' + \
        CLI.print_avail_options(CLI.AVAIL_QTYPE, indent)
    docs_hyper_monitor_qtype_str = \
        '  --hyper-monitor-qtype=QTYPE quadrature type for the discretization of ' + \
        'the L2-norm\n' + \
        CLI.print_avail_options(CLI.AVAIL_QTYPE, indent)
    docs_hyper_adaptivity_str = \
        '  --hyper-adapt=ADAPT     adaptivity algorithm for hyper-map regression\n' + \
        CLI.print_avail_options(CLI.AVAIL_REGRESSION_ADAPTIVITY, indent)
    docs_log_str = \
        '  --log=LOG               log level (default=30). Uses package logging.\n' + \
        CLI.print_avail_options(CLI.AVAIL_LOGGING, indent)

    docs_str = """DESCRIPTION
Given a file (--dist) storing the target distribution, produce the transport map that
pushes forward the base distribution (default: standard normal) to the target distribution,
using the algorithm for sequential low-dimensional couplings.
The input distribution must be SequentialHiddenMarkovChainDistribution with
hdim hyperparameters and sdim state dimension.
All files involved are stored and loaded using the python package dill.

OPTIONS - input/output:
  --dist=DIST             path to the file containing the target distribution
  --output=OUTPUT         path to the output file containing the transport map,
                          the base distribution, the target distribution and all
                          the additional parameters used for the construction
OPTIONS - map description (using default maps):
""" + docs_monotone_str + docs_span_str + docs_btype_str + \
"""  --order=ORDER           order of the transport map
OPTIONS - map description (manual):
  --map-0-descr=MAP_DESCR XML file with the skeleton of the zero-th transport map (hdim+sdim)
  --map-descr=MAP_DESCR   XML file with the skeleton of the transport map (hdim+2*sdim)
OPTIONS - KL-minimization solver:
""" + docs_qtype_str + \
"""  --qnum=QNUM             quadrature level
  --tol=TOL               optimization tolerance (default: 1e-4)
  --maxit=MAXIT           maximum number of iterations
  --reg=REG               a float L2 regularization parameter
                          (default: no regularization)
""" + docs_ders_str + \
"""  --fungrad               whether the distributions provide a method to compute
                          the log pdf and its gradient at the same time
""" + docs_adaptivity_str + \
"""  --adapt-tol=TOL         target tolerance for map adaptivity
OPTIONS - KL-divergence monitoring:
""" + docs_monitor_qtype_str + \
"""  --monitor-qnum=QNUM     quadrature level
OPTIONS - hyper-parameters map description (using default maps):
""" + docs_hyper_monotone_str + docs_hyper_span_str + docs_hyper_btype_str + \
"""  --hyper-order=ORDER     order of the transport map
OPTIONS - hyper-parameters map description (manual):
  --hyper-map-descr=MAP_DESCR   XML file containing the skeleton of the transport map
OPTIONS - regression solver:
""" + docs_hyper_qtype_str + \
"""  --hyper-qnum=QNUM       quadrature level
  --hyper-tol=TOL         optimization tolerance (default: 1e-4)
  --hyper-maxit=MAXIT     maximum number of iterations
  --hyper-reg=REG         a float L2 regularization parameter
                          (default: no regularization)
""" + docs_hyper_adaptivity_str + \
"""  --hyper-adapt-tol=TOL   target tolerance for adaptivity algorithm
OPTIONS - regression monitoring:
""" + docs_hyper_monitor_qtype_str + \
"""  --hyper-monitor-qnum=QNUM quadrature level
OPTIONS - other:
  --continue-on-error     whether to continue the assimilation when an error is raised,
                          by taking predefined back-up actions
  --safe-mode=LAG         store intermediate maps every LAG iterations (allows for restarting)
                          Default is LAG=0, i.e. do not store.
  --reload                automatically resume from stored data (conflicts with -f)
""" + docs_log_str + \
"""  --email=EMAIL           send an email to the desired address at termination or error
  --nprocs=NPROCS         number of processors to be used (default=1)
  --batch=BATCH           batch size for function evaluation, gradient
                          evaluation and Hessian evaluation. This size includes the
                          number of coefficients involved in the actual approximation
  -v                      verbose output (not affecting --log)
  -f                      force overwrite of OUTPUT file
  -I                      enter interactive mode after finishing
  -h                      print this help
"""
    print(docs_str)

def full_usage():
    """Print the command line synopsis (thin alias of ``usage``)."""
    return usage()

def full_doc():
    """Print the synopsis followed by the detailed option description."""
    for print_section in (full_usage, description):
        print_section()

##################### INPUT PARSING #####################
# Default values of every command line setting; the option parsing loop
# below overwrites them. Settings attached to ``stg`` are replaced wholesale
# when an existing OUTPUT file is reloaded, whereas the plain module globals
# always come from the current command line.
argv = sys.argv[1:]
INTERACTIVE = False
# I/O
DIST_FNAME = None
OUT_FNAME = None
BASE_DIST_FNAME = None
FORCE = False
# Map type
MONOTONE = None
SPAN = None
BTYPE = None
ORDER = None
MAP_0_DESCR = None
MAP_DESCR = None
ADAPT = 'none'  # validated against CLI.AVAIL_ADAPTIVITY after parsing
ADAPT_TOL = None
# Quadrature type
stg.QTYPE = None
stg.QNUM = None
# KL-minimization Solver options
stg.TOL = 1e-4
stg.MAXIT = 100
stg.REG = None
stg.DERS = 2
stg.FUNGRAD = False
# KL-divergence monitoring
MONITOR_QTYPE = None
MONITOR_QNUM = None
# Hyper Map type
HYPER_MONOTONE = None
HYPER_SPAN = None
HYPER_BTYPE = None
HYPER_ORDER = None
HYPER_MAP_DESCR = None
HYPER_ADAPT = 'none'  # validated against CLI.AVAIL_REGRESSION_ADAPTIVITY after parsing
HYPER_ADAPT_TOL = None
# Hyper Quadrature type
stg.HYPER_QTYPE = None
stg.HYPER_QNUM = None
# Regression Solver options
stg.HYPER_TOL = 1e-4
stg.HYPER_MAXIT = 100
stg.HYPER_REG = None
# Regression monitoring
HYPER_MONITOR_QTYPE = None
HYPER_MONITOR_QNUM = None
# Safe mode
CONTINUE_ON_ERROR = False
SAFE_MODE = 0
RELOAD = False
# Logging
VERBOSE = False
LOGGING_LEVEL = 30 # Warnings
EMAIL = None
# Parallelization
NPROCS = 1
# Very large default; later divided by n_coeffs, n_coeffs**2 and n_coeffs**3
# to obtain the batch sizes for function, gradient and Hessian evaluations.
BATCH_SIZE = int(1e9)
# Parse the command line. Short flags: -h (help), -f (force overwrite),
# -v (verbose), -I (interactive). In the long option list a trailing '='
# means that the option takes a value.
try:
    opts, args = getopt.getopt(argv,"hfvI",
                               [
                                   # I/O
                                   "dist=", "output=", "base-dist=",
                                   # Map type
                                   "mtype=", "span=", "btype=", "order=",
                                   "map-0-descr=", "map-descr=",
                                   # Quadrature type
                                   "qtype=", "qnum=",
                                   # KL-minimization solver options
                                   "tol=", "maxit=", "reg=", "ders=", "fungrad",
                                   # Adaptivity options
                                   "adapt=", "adapt-tol=",
                                   # KL-divergence monitoring
                                   "monitor-qtype=", "monitor-qnum=",
                                   # Hyper map type
                                   "hyper-mtype=", "hyper-span=", "hyper-btype=",
                                   "hyper-order=",
                                   "hyper-map-descr=",
                                   # Hyper Quadrature type
                                   "hyper-qtype=", "hyper-qnum=",
                                   # Regression solver options
                                   "hyper-tol=", "hyper-maxit=", "hyper-reg=",
                                   # Regression adaptivity options
                                   "hyper-adapt=", "hyper-adapt-tol=",
                                   # Regression monitoring
                                   "hyper-monitor-qtype=", "hyper-monitor-qnum=",
                                   # Safe mode
                                   "continue-on-error", "safe-mode=", "reload",
                                   # Logging
                                   "log=", "email=",
                                   # Parallelization and batching option
                                   "nprocs=", "batch="
                               ])
except getopt.GetoptError:
    # Show the synopsis, then re-raise so the getopt error message is reported.
    full_usage()
    raise
# Translate every parsed option into the corresponding module/stg setting.
for opt, arg in opts:
    if opt == '-h':
        # Print the complete help and stop.
        full_doc()
        sys.exit()

    # Force overwrite
    elif opt == '-f':
        FORCE = True

    # Verbose
    elif opt == '-v':
        VERBOSE = True

    # Interactive
    elif opt == '-I':
        INTERACTIVE = True

    # I/O
    elif opt in ['--dist']:
        DIST_FNAME = arg
    elif opt in ['--output']:
        OUT_FNAME = arg
    elif opt in ['--base-dist']:
        BASE_DIST_FNAME = arg

    # Map type
    elif opt in ['--mtype']:
        MONOTONE = arg
    elif opt in ['--span']:
        # Comma-separated values give a per-component list; a single value
        # is kept as a plain string.
        SPAN = [q for q in arg.split(',')]
        if len(SPAN) == 1:
            SPAN = SPAN[0]
    elif opt in ['--btype']:
        # Same convention as --span.
        BTYPE = [q for q in arg.split(',')]
        if len(BTYPE) == 1:
            BTYPE = BTYPE[0]
    elif opt in ['--order']:
        ORDER = int(arg)
    elif opt in ['--map-0-descr']:
        MAP_0_DESCR = arg
    elif opt in ['--map-descr']:
        MAP_DESCR = arg

    # Quadrature type
    elif opt in ['--qtype']:
        stg.QTYPE = int(arg)
    elif opt in ['--qnum']:
        # Always parsed as a list of ints; it may be reduced to a scalar
        # afterwards depending on the quadrature type.
        stg.QNUM = [int(q) for q in arg.split(',')]

    # KL-minimization solver options
    elif opt in ['--tol']:
        stg.TOL = float(arg)
    elif opt == '--maxit':
        stg.MAXIT = int(arg)
    elif opt in ['--reg']:
        # One L2 regularization entry per comma-separated alpha value.
        stg.REG = [ {'type': 'L2',
                     'alpha': float(q)} for q in arg.split(',') ]
    elif opt in ['--ders']:
        stg.DERS = int(arg)
    elif opt == '--fungrad':
        stg.FUNGRAD = True

    # Adaptivity options
    elif opt == '--adapt':
        ADAPT = arg
    elif opt == '--adapt-tol':
        ADAPT_TOL = float(arg)

    # KL-divergence monitoring
    elif opt == '--monitor-qtype':
        MONITOR_QTYPE = int(arg)
    elif opt == '--monitor-qnum':
        MONITOR_QNUM = [int(q) for q in arg.split(',')]

    # Hyper-parameters Map type
    elif opt in ['--hyper-mtype']:
        HYPER_MONOTONE = arg
    elif opt in ['--hyper-span']:
        # NOTE(review): unlike --span, no comma splitting is done here.
        HYPER_SPAN = arg
    elif opt in ['--hyper-btype']:
        HYPER_BTYPE = arg
    elif opt in ['--hyper-order']:
        HYPER_ORDER = int(arg)
    elif opt in ['--hyper-map-descr']:
        HYPER_MAP_DESCR = arg

    # Hyper-parameters Quadrature type
    elif opt in ['--hyper-qtype']:
        stg.HYPER_QTYPE = int(arg)
    elif opt in ['--hyper-qnum']:
        stg.HYPER_QNUM = [int(q) for q in arg.split(',')]

    # Regression solver options
    elif opt in ['--hyper-tol']:
        stg.HYPER_TOL = float(arg)
    elif opt == '--hyper-maxit':
        stg.HYPER_MAXIT = int(arg)
    elif opt in ['--hyper-reg']:
        # NOTE(review): a single dict here, whereas --reg produces a list of
        # dicts — confirm against the downstream regression code.
        stg.HYPER_REG = {'type': 'L2',
                   'alpha': float(arg)}

    # Regression adaptivity options
    elif opt in ['--hyper-adapt']:
        HYPER_ADAPT = arg
    elif opt in ['--hyper-adapt-tol']:
        HYPER_ADAPT_TOL = float(arg)

    # Regression monitoring
    elif opt == '--hyper-monitor-qtype':
        HYPER_MONITOR_QTYPE = int(arg)
    elif opt == '--hyper-monitor-qnum':
        HYPER_MONITOR_QNUM = [int(q) for q in arg.split(',')]

    # Safe mode
    elif opt == '--continue-on-error':
        CONTINUE_ON_ERROR = True
    elif opt in ['--safe-mode']:
        SAFE_MODE = int(arg)
    elif opt in ['--reload']:
        RELOAD = True

    # Logging
    elif opt in ['--log']:
        LOGGING_LEVEL = int(arg)
    elif opt in ['--email']:
        EMAIL = arg

    # Parallelization and batching
    elif opt in ['--nprocs']:
        NPROCS = int(arg)
    elif opt in ['--batch']:
        # Parsed through float so that values such as '1e6' are accepted.
        BATCH_SIZE = int(float(arg))

    else:
        # Defensive only: getopt already rejects options not listed above.
        raise ValueError("Option %s not recognized" % opt)

# Apply the requested logging level to the TransportMaps package.
TM.setLogLevel(LOGGING_LEVEL)

def tstamp_print(msg, *args, **kwargs):
    """Print ``msg`` prefixed by the current local time.

    The prefix has the form ``YYYY-mm-dd HH:MM:SS`` followed by one space.
    Any extra positional and keyword arguments are forwarded to ``print``.
    """
    now_str = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    print(' '.join((now_str, msg)), *args, **kwargs)

def filter_tstamp_print(msg, *args, **kwargs):
    """Timestamped print that is emitted only in verbose mode (``-v``)."""
    if not VERBOSE:
        return
    tstamp_print(msg, *args, **kwargs)

def filter_print(*args, **kwargs):
    """Forward all arguments to ``print`` only in verbose mode (``-v``)."""
    if not VERBOSE:
        return
    print(*args, **kwargs)

def safe_store(data, fname):
    """Pickle ``data`` into ``fname``, keeping a backup while writing.

    If ``fname`` already exists it is first copied to ``fname + '.bak'``.
    The backup is removed only after the dump succeeded, so if pickling or
    writing raises, the exception propagates and the previous content is
    still available in the ``.bak`` file.

    Args:
      data: object to serialize with the module bound to ``pickle``.
      fname (str): destination path.
    """
    backup_fname = fname + '.bak'
    # Backup copy
    if os.path.exists(fname):
        shutil.copyfile(fname, backup_fname)
    # Store data
    with open(fname, 'wb') as out_stream:
        pickle.dump(data, out_stream)
    # Remove backup (best effort: it may not exist, e.g. on the first store).
    # Only filesystem errors are ignored; interrupts and other errors are not.
    try:
        os.remove(backup_fname)
    except OSError:
        pass
        
# Check for required arguments
if None in [DIST_FNAME, OUT_FNAME]:
    usage()
    tstamp_print("ERROR: Options --dist and --output must be specified")
    sys.exit(3)
if FORCE and RELOAD:
    usage()
    tstamp_print("ERROR: options clash -f --reload")
    sys.exit(3)
# With --reload the quadrature settings may be omitted, because ``stg`` is
# replaced by the stored object when OUTPUT is reloaded.
if not RELOAD and None in [stg.QTYPE, stg.QNUM]:
    usage()
    tstamp_print("ERROR: Options --qtype and --qnum must be specified")
    sys.exit(3)
# Quadrature types below 3 are used with a scalar level: keep only the first
# entry. QTYPE can still be None here under --reload, so guard the comparison
# (None < 3 raises TypeError on Python 3).
if stg.QTYPE is not None and stg.QTYPE < 3 and isinstance(stg.QNUM, list):
    stg.QNUM = stg.QNUM[0]
# Monitoring quadratures need both the type and the level.
if MONITOR_QTYPE is not None and MONITOR_QNUM is None:
    usage()
    tstamp_print("ERROR: Option --monitor-qnum must be specified " + \
                 "together with --monitor-qtype")
    sys.exit(3)
if MONITOR_QTYPE is not None and MONITOR_QTYPE < 3:
    MONITOR_QNUM = MONITOR_QNUM[0]
if HYPER_MONITOR_QTYPE is not None and HYPER_MONITOR_QNUM is None:
    usage()
    tstamp_print("ERROR: Option --hyper-monitor-qnum must be specified " + \
                 "together with --hyper-monitor-qtype")
    sys.exit(3)
if HYPER_MONITOR_QTYPE is not None and HYPER_MONITOR_QTYPE < 3:
    HYPER_MONITOR_QNUM = HYPER_MONITOR_QNUM[0]
# The map is described either by the default-map options or by XML files.
map_descr_list = [MONOTONE, SPAN, BTYPE, ORDER]
if (MAP_DESCR is None) != (MAP_0_DESCR is None):
    usage()
    tstamp_print("ERROR: --map-0-descr and --map-descr must be " + \
                 "specified simultaneously.")
    sys.exit(3)
if (MAP_DESCR is None and None in map_descr_list) or \
   (MAP_DESCR is not None and any(s is not None for s in map_descr_list)):
    usage()
    tstamp_print("ERROR: Either options --mtype, --span, --btype, " + \
                 "--order are specified or option --map-descr is specified")
    sys.exit(3)
if ADAPT not in CLI.AVAIL_ADAPTIVITY:
    usage()
    tstamp_print("ERROR: adaptivity algorithm not recognized")
    sys.exit(3)
if HYPER_ADAPT not in CLI.AVAIL_REGRESSION_ADAPTIVITY:
    usage()
    tstamp_print("ERROR: regression adaptivity algorithm not recognized")
    sys.exit(3)

def _ask_choice(question, default):
    """Prompt on stdin until one of y/Y/n/N/q is given; empty input -> default."""
    answer = ''
    while answer not in ['y', 'Y', 'n', 'N', 'q']:
        # Python 2 needs raw_input; Python 3 only has input.
        read_line = input if sys.version_info[0] == 3 else raw_input
        answer = read_line(question) or default
    return answer

# An existing OUTPUT file is either reloaded (resume), overwritten, or the
# run is aborted. -f skips all of this; --reload reloads without asking.
if os.path.exists(OUT_FNAME) and not FORCE:
    if not RELOAD:
        sel = _ask_choice("The file %s already exists. " % OUT_FNAME + \
                          "Do you want to reload it? [Y/n/q] ", 'y')
    if RELOAD or sel in ('y', 'Y'):
        tstamp_print("Reloading data...")
        with open(OUT_FNAME, 'rb') as in_stream:
            stg = pickle.load(in_stream)
        RELOAD = True
    elif sel == 'q':
        tstamp_print("Terminating.")
        sys.exit(0)
    else:
        sel = _ask_choice("Do you want to overwrite? [y/N/q] ", 'n')
        if sel in ('n', 'N', 'q'):
            tstamp_print("Terminating.")
            sys.exit(0)

try:
    ##################### DATA LOADING #####################
    if not RELOAD or not os.path.exists(OUT_FNAME):
        # Load target distribution
        with open(DIST_FNAME,'rb') as istr:
            stg.target_distribution = pickle.load(istr)
        dim = stg.target_distribution.dim

        if not isinstance(stg.target_distribution,
                          DECDIST.SequentialHiddenMarkovChainDistribution):
            raise ValueError("The input distribution must be a " + \
                             "SequentialHiddenMarkovChainDistribution")
        if len(stg.target_distribution.pi_list) < 1:
            raise ValueError("The input hidden Markov chain distribution " + \
                             "must perform at least " + \
                             "one transition (Z0,Z1). " + \
                             "If no transition is performed (Z0) use tmap-tm.")

        # Load base distribution
        stg.base_distribution = DIST.StandardNormalDistribution(dim)

        # Build integrator
        stg.integrator = ALGSI.TransportMapsSmoother(
            stg.target_distribution.pi_hyper)

    hdim = stg.target_distribution.hyper_dim
    sdim = stg.target_distribution.state_dim

    # Start pool if necessary
    mpi_pool = None
    if NPROCS > 1:
        mpi_pool = TM.get_mpi_pool()
        mpi_pool.start(NPROCS)
        # mpi_pool.set_log_level(10)

    # Start integration
    nstart = stg.integrator.nsteps + 1
    step_str = "Step %%%dd " % (int(np.floor(np.log10(stg.target_distribution.nsteps)))+1)
    try:
        safe_mode_iter = 0
        for n, pi, ll in zip(
                range(nstart, stg.target_distribution.nsteps),
                stg.target_distribution.pi_list[nstart:],
                stg.target_distribution.ll_list[nstart:] ):
            safe_mode_iter += 1
            filter_tstamp_print(step_str % n, end="")
            if n == 0:
                # Instantiate Transport Map
                if MAP_0_DESCR is not None:
                    if ADAPT in ['sequential','tol-sequential']:
                        usage()
                        tstamp_print("ERROR: option --adapt=sequential is not available with --map-0-descr (yet..)")
                        sys.exit(3)
                    tm_approx = TMXML.load_xml(MAP_0_DESCR)
                else:
                    if MONOTONE == 'linspan':
                        map_constructor = TM.Default_IsotropicMonotonicLinearSpanTriangularTransportMap
                    elif MONOTONE == 'intexp':
                        map_constructor = TM.Default_IsotropicIntegratedExponentialTriangularTransportMap
                    elif MONOTONE == 'intsq':
                        map_constructor = TM.Default_IsotropicIntegratedSquaredTriangularTransportMap
                    else:
                        raise ValueError("Monotone type not recognized (linspan|intexp|intsq)")
                    span = SPAN[:-sdim] if isinstance(SPAN, list) else SPAN
                    btype = BTYPE[:-sdim] if isinstance(BTYPE, list) else BTYPE
                    if ADAPT == 'none':
                        tm_approx = map_constructor(hdim+sdim, ORDER, span=span, btype=btype)
                        logging.info("Number coefficients: %d" % tm_approx.n_coeffs)
                    elif ADAPT in ['sequential','tol-sequential']:
                        tm_approx = [map_constructor(hdim+sdim, o, span=span, btype=btype) for o in range(1,ORDER+1)]
                        n_coeffs = sum( tm.n_coeffs for tm in tm_approx )
                        logging.info("Number coefficients: %d" % n_coeffs )
                # Solution builder
                if ADAPT == 'none':
                    builder = ALGADPT.KullbackLeiblerBuilder()
                elif ADAPT == 'sequential':
                    builder = ALGADPT.SequentialKullbackLeiblerBuilder()
                elif ADAPT == 'tol-sequential':
                    builder = ALGADPT.ToleranceSequentialKullbackLeiblerBuilder()
                # Prepare solution parameters 
                if isinstance(stg.QNUM, list):
                    qnum = stg.QNUM[:hdim+sdim]
                else:
                    qnum = stg.QNUM
                if ADAPT == 'none':
                    nc = tm_approx.n_coeffs
                    BSIZE = (max(1,BATCH_SIZE//nc),
                             max(1,BATCH_SIZE//nc**2),
                             max(1,BATCH_SIZE//nc**3))
                    solve_params = {'qtype': stg.QTYPE, 'qparams': qnum,
                                    'tol': stg.TOL, 'maxit': stg.MAXIT,
                                    'regularization': None if stg.REG is None else stg.REG[0],
                                    'ders': stg.DERS, 'fungrad': stg.FUNGRAD,
                                    'batch_size': BSIZE, 'mpi_pool': mpi_pool}
                elif ADAPT in ['sequential', 'tol-sequential']:
                    if stg.REG is not None and len(stg.REG) == 1:
                        stg.REG = stg.REG * ORDER
                    solve_params = {'solve_params_list': []}
                    for i, tm in enumerate(tm_approx):
                        nc = tm.n_coeffs
                        BSIZE = (max(1,BATCH_SIZE//nc),
                                 max(1,BATCH_SIZE//nc**2),
                                 max(1,BATCH_SIZE//nc**3))
                        solve_params['solve_params_list'].append( {
                            'qtype': stg.QTYPE, 'qparams': qnum,
                            'tol': stg.TOL, 'maxit': stg.MAXIT,
                            'regularization': None if stg.REG is None else stg.REG[i],
                            'ders': stg.DERS, 'fungrad': stg.FUNGRAD,
                            'batch_size': BSIZE, 'mpi_pool': mpi_pool
                        } )
                    if ADAPT == 'tol-sequential':
                        qnum = MONITOR_QNUM[:hdim+sdim] \
                           if isinstance(MONITOR_QNUM, list) \
                              else MONITOR_QNUM
                        solve_params['tol'] = ADAPT_TOL
                        solve_params['var_diag_params'] = {
                            'qtype': MONITOR_QTYPE, 'qparams': qnum,
                            'mpi_pool_tuple': (None, mpi_pool) }
                # Prepare KL-divergence monitoring parameters
                var_diag_convergence_params = None
                if MONITOR_QTYPE is not None:
                    qnum = MONITOR_QNUM[:hdim+sdim] \
                           if isinstance(MONITOR_QNUM, list) \
                              else MONITOR_QNUM
                    var_diag_convergence_params = {
                        'qtype': MONITOR_QTYPE, 'qparams': qnum,
                        'mpi_pool_tuple': (None, mpi_pool) }
                # Assimilate
                stg.integrator.assimilate(
                    pi, ll,
                    tm=tm_approx, solve_params=solve_params,
                    builder=builder,
                    var_diag_convergence_params=var_diag_convergence_params,
                    continue_on_error=CONTINUE_ON_ERROR)
            else:
                # Instantiate Transport Map
                if MAP_DESCR is not None:
                    if ADAPT == ['sequential','tol-sequential']:
                        usage()
                        tstamp_print("ERROR: option --adapt=sequential is not available with --map-0-descr (yet..)")
                        sys.exit(3)
                    tm_approx = TMXML.load_xml(MAP_DESCR)
                else:
                    if MONOTONE == 'linspan':
                        map_constructor = TM.Default_IsotropicMonotonicLinearSpanTriangularTransportMap
                    elif MONOTONE == 'intexp':
                        map_constructor = TM.Default_IsotropicIntegratedExponentialTriangularTransportMap
                    elif MONOTONE == 'intsq':
                        map_constructor = TM.Default_IsotropicIntegratedSquaredTriangularTransportMap
                    else:
                        raise ValueError("Monotone type not recognized (linspan|intexp|intsq)")
                    if ADAPT == 'none':
                        tm_approx = map_constructor(hdim+2*sdim, ORDER, span=SPAN, btype=BTYPE)
                        logging.info("Number coefficients: %d" % tm_approx.n_coeffs)
                    elif ADAPT in ['sequential', 'tol-sequential']:
                        tm_approx = [map_constructor(hdim+2*sdim, o, span=SPAN, btype=BTYPE)
                                     for o in range(1,ORDER+1)]
                        n_coeffs = sum( tm.n_coeffs for tm in tm_approx )
                        logging.info("Number coefficients: %d" % n_coeffs )
                # Solution builder
                if ADAPT == 'none':
                    builder = ALGADPT.KullbackLeiblerBuilder()
                elif ADAPT == 'sequential':
                    builder = ALGADPT.SequentialKullbackLeiblerBuilder()
                elif ADAPT == 'tol-sequential':
                    builder = ALGADPT.ToleranceSequentialKullbackLeiblerBuilder()
                # Prepare solution parameters
                if ADAPT == 'none':
                    nc = tm_approx.n_coeffs
                    BSIZE = (max(1,BATCH_SIZE//nc),
                             max(1,BATCH_SIZE//nc**2),
                             max(1,BATCH_SIZE//nc**3))
                    solve_params = {
                        'qtype': stg.QTYPE, 'qparams': stg.QNUM,
                        'tol': stg.TOL, 'maxit': stg.MAXIT,
                        'regularization': None if stg.REG is None else stg.REG[0],
                        'ders': stg.DERS, 'fungrad': stg.FUNGRAD,
                        'batch_size': BSIZE, 'mpi_pool': mpi_pool
                    }
                elif ADAPT in ['sequential','tol-sequential']:
                    if stg.REG is not None and len(stg.REG) == 1:
                        stg.REG = stg.REG * ORDER
                    solve_params = {'solve_params_list': []}
                    for i, tm in enumerate(tm_approx):
                        nc = tm.n_coeffs
                        BSIZE = (max(1,BATCH_SIZE//nc),
                                 max(1,BATCH_SIZE//nc**2),
                                 max(1,BATCH_SIZE//nc**3))
                        solve_params['solve_params_list'].append( {
                            'qtype': stg.QTYPE, 'qparams': stg.QNUM,
                            'tol': stg.TOL, 'maxit': stg.MAXIT,
                            'regularization': None if stg.REG is None else stg.REG[i],
                            'ders': stg.DERS, 'fungrad': stg.FUNGRAD,
                            'batch_size': BSIZE, 'mpi_pool': mpi_pool
                        } )
                    if ADAPT == 'tol-sequential':
                        solve_params['tol'] = ADAPT_TOL
                        solve_params['var_diag_params'] = {
                            'qtype': MONITOR_QTYPE, 'qparams': MONITOR_QNUM,
                            'mpi_pool_tuple': (None, mpi_pool) }
                # Variance-diagnostic (KL-divergence) monitoring settings.
                # Monitoring is enabled only when a monitoring quadrature type
                # was requested; otherwise the integrator receives None.
                if MONITOR_QTYPE is None:
                    var_diag_convergence_params = None
                else:
                    var_diag_convergence_params = {
                        'qtype': MONITOR_QTYPE,
                        'qparams': MONITOR_QNUM,
                        'mpi_pool_tuple': (None, mpi_pool),
                    }
                # Instantiate the hyper-parameters transport map, if requested.
                # It is either loaded from an XML description (--hyper-map-descr)
                # or built from one of the default isotropic triangular map
                # families selected by HYPER_MONOTONE.
                hyper_tm_approx = None
                if HYPER_MAP_DESCR is not None:
                    # The tol-sequential hyper adaptivity is rejected for maps
                    # loaded from XML; it is only wired up for the default
                    # map families below.
                    if HYPER_ADAPT == 'tol-sequential':
                        usage()
                        tstamp_print("ERROR: option --hyper-adapt=tol-sequential is not available with --hyper-map-descr (yet..)")
                        sys.exit(3)
                    hyper_tm_approx = TMXML.load_xml(HYPER_MAP_DESCR)
                elif HYPER_MONOTONE is not None:
                    # Select the constructor for the hyper-parameter map (not the
                    # main map's ``map_constructor``).
                    if HYPER_MONOTONE == 'linspan':
                        hyper_map_constructor = TM.Default_IsotropicMonotonicLinearSpanTriangularTransportMap
                    elif HYPER_MONOTONE == 'intexp':
                        hyper_map_constructor = TM.Default_IsotropicIntegratedExponentialTriangularTransportMap
                    elif HYPER_MONOTONE == 'intsq':
                        hyper_map_constructor = TM.Default_IsotropicIntegratedSquaredTriangularTransportMap
                    else:
                        raise ValueError("Monotone type not recognized (linspan|intexp|intsq)")
                    if HYPER_ADAPT == 'none':
                        hyper_tm_approx = hyper_map_constructor(
                            hdim, HYPER_ORDER, span=HYPER_SPAN, btype=HYPER_BTYPE)
                    elif HYPER_ADAPT == 'tol-sequential':
                        # One candidate map for each order 1..HYPER_ORDER.
                        hyper_tm_approx = [
                            hyper_map_constructor(hdim, o, span=HYPER_SPAN, btype=HYPER_BTYPE)
                            for o in range(1, HYPER_ORDER + 1)]
                # Prepare the regression builder and its parameters for the
                # hyper-parameter map (only when such a map was instantiated).
                # All three stay None otherwise.
                reg_params = None
                reg_builder = None
                regression_convergence_params = None
                if hyper_tm_approx is not None:
                    rho_hyper = DIST.StandardNormalDistribution(hdim)
                    if HYPER_ADAPT == 'none':
                        reg_builder = ALGADPT.L2RegressionBuilder()
                        reg_params = {'d': rho_hyper,
                                      'qtype': stg.HYPER_QTYPE, 'qparams': stg.HYPER_QNUM,
                                      'tol': stg.HYPER_TOL, 'maxit': stg.HYPER_MAXIT,
                                      'regularization': stg.HYPER_REG}
                    elif HYPER_ADAPT == 'tol-sequential':
                        reg_builder = ALGADPT.ToleranceSequentialL2RegressionBuilder()
                        # Build one independent parameter dict per order.
                        # ``[d] * HYPER_ORDER`` would alias the same dict
                        # HYPER_ORDER times, so mutating any entry would
                        # silently change all of them.
                        reg_params = {
                            'eps': HYPER_ADAPT_TOL,
                            'regression_params_list': [
                                {'d': rho_hyper,
                                 'qtype': stg.HYPER_QTYPE, 'qparams': stg.HYPER_QNUM,
                                 'tol': stg.HYPER_TOL, 'maxit': stg.HYPER_MAXIT,
                                 'regularization': stg.HYPER_REG}
                                for _ in range(HYPER_ORDER)],
                            'monitor_params': {
                                'd': rho_hyper,
                                'qtype': HYPER_MONITOR_QTYPE,
                                'qparams': HYPER_MONITOR_QNUM }
                            }
                    # Prepare regression monitoring parameters
                    if HYPER_MONITOR_QTYPE is not None:
                        regression_convergence_params = {
                            'd': rho_hyper,
                            'qtype': HYPER_MONITOR_QTYPE, 'qparams': HYPER_MONITOR_QNUM }
                # Hand the current step (pi, ll) and everything prepared above
                # to the sequential integrator. The keyword arguments are
                # gathered first to keep the call site compact.
                assimilate_kwargs = {
                    'tm': tm_approx,
                    'solve_params': solve_params,
                    'builder': builder,
                    'hyper_tm': hyper_tm_approx,
                    'regression_params': reg_params,
                    'hyper_builder': reg_builder,
                    'var_diag_convergence_params': var_diag_convergence_params,
                    'regression_convergence_params': regression_convergence_params,
                    'continue_on_error': CONTINUE_ON_ERROR,
                }
                stg.integrator.assimilate(pi, ll, **assimilate_kwargs)
            # Append the latest variance diagnostic to the current status line
            # (only recorded when monitoring was enabled).
            if MONITOR_QTYPE is not None:
                filter_print("  Var.Diag.=%.2e" % stg.integrator.var_diag_convergence[-1], end="")
            # Append the latest regression diagnostic; skipped while n <= 1
            # (presumably no hyper-parameter regression has run yet — confirm).
            if HYPER_MONITOR_QTYPE is not None and n > 1:
                filter_print("  Reg.Diag.=%.2e" % stg.integrator.regression_convergence[-1], end="")
            filter_print("    [DONE]")
            # Periodically checkpoint the partial integration every SAFE_MODE
            # iterations, then reset the counter.
            if safe_mode_iter == SAFE_MODE: # Store partial integration
                filter_tstamp_print("Storing...")
                safe_store(stg, OUT_FNAME)
                safe_mode_iter = 0
    finally:
        # Always shut down the MPI pool, even if the assimilation loop failed.
        if mpi_pool is not None:
            mpi_pool.stop()

    # Collect the smoothing map assembled by the sequential integrator
    stg.tmap = stg.integrator.smoothing_map

    # Approximations induced by the smoothing map: the pull-back of the
    # target distribution and the push-forward of the base distribution.
    stg.approx_base_distribution = DIST.PullBackTransportMapDistribution(stg.tmap, stg.target_distribution)
    stg.approx_target_distribution = DIST.PushForwardTransportMapDistribution(stg.tmap, stg.base_distribution)

    # STORE the final state to OUT_FNAME
    filter_tstamp_print("Storing...")
    safe_store(stg, OUT_FNAME)

    # Optional success notification by e-mail
    if EMAIL is not None:
        sendmail(EMAIL, "tmap-sequential-tm - termination", "Success")

except Exception as e:
    if EMAIL is not None:
        sendmail(EMAIL, "tmap-sequential-tm - error", str(e))
    raise e
        
finally:
    if INTERACTIVE:
        from IPython import embed
        embed()
