#!/usr/bin/env python3

"""
A script for running a time-dependent simulation
of the 1D hydrogen model in an external laser pulse.
"""

import argparse

import h5py
import numpy as np
import progressbar
import scipy.sparse as sp

try:
    from scipy.integrate import simps
except ImportError:
    # SciPy 1.14 removed `simps`; `simpson` (available since SciPy 1.6) is the
    # same composite Simpson rule. Pass the sample points as `x=` (keyword) so
    # the call works with either function.
    from scipy.integrate import simpson as simps

from quantum_dynamics.tise import get_initial_state
from quantum_dynamics.tdse import evolve_timestep, complex_absorbing_potential
from quantum_dynamics.utils import save_sparse_matrix


def potential(x):
    """Soft-core Coulomb potential of the 1D hydrogen model.

    Evaluates V(x) = -1 / sqrt(x**2 + 1). The ``+ 1`` under the square root
    keeps the potential finite at the origin. Accepts scalars or NumPy arrays
    and is applied elementwise.
    """
    softened_distance = np.sqrt(x**2 + 1)
    return -1.0 / softened_distance


def step(x):
    """Heaviside step function, applied elementwise.

    Because it is built from ``np.sign``, it returns 0 for x < 0, 1 for x > 0,
    and exactly 1/2 at x == 0.
    """
    return (np.sign(x) + 1) * 0.5


def laser(t, A, w, T):
    """Laser electric field at time t.

    The field is a carrier cos(w t) with amplitude A under a sin^2(pi t / T)
    envelope. It is gated by ``step(t) * step(T - t)``, so it is zero outside
    [0, T]. The gate equals 1/2 exactly at t = 0 and t = T, where the envelope
    is (numerically) zero anyway.

    Parameters
    ----------
    t : float or ndarray
        Time(s) at which to evaluate the field.
    A : float
        Field amplitude.
    w : float
        Carrier angular frequency (it enters as cos(w * t)).
    T : float
        Pulse duration.
    """
    envelope = np.sin(np.pi / T * t)**2
    carrier = np.cos(w * t)
    # Keep the original left-to-right multiplication order so the floating
    # point result is bit-identical.
    return A * envelope * carrier * step(t) * step(T - t)


# Human-readable description stored as an attribute on the HDF5 root group.
_SAVEFILE_DESCRIPTION = """This HDF5-file contains
a simulation of time-dependent Schrödinger equation of an electron in 1D
hydrogen under laser-pulse. 

You can find the coordinate grid in '/coordinate_grid'
and the time-steps where we saved data in '/savetimes'. 
The laser electric field is stored in '/laser' so that the first column corresponds 
to time and the second column to laser electric field.

The final values of the wavefunction at the end of the simulation are saved in
`/final_wavefunction` and the wavefunction values at the savetimes in
`/wavefunction`.

Time-independent part of the Hamiltonian is saved in `/tise_hamiltonian`. For
details on how to load it, please see
https://pypi.org/project/quantum-dynamics/.

The data has been calculated using the 'qdyn_laser' package developed during
the computational physics course at TUT in spring 2018."""

# Sparse formats accepted by --sparse-matrix-format (the same list its help
# text advertises).
_SPARSE_FORMATS = ('bsr', 'coo', 'csc', 'csr', 'dia')


def build_arg_parser():
    """Return the command-line parser for the simulation script."""
    parser = argparse.ArgumentParser(description="""Simulation of TDSE for 1D
                                     Hydrogen electron under laser electric
                                     field in the dipole approximation""")

    parser.add_argument('--delta-t', '-dt', type=float, default=5e-1,
                        help="Time-step")
    parser.add_argument('--max-t', '-T', type=float, default=1000,
                        help="Maximum simulated time")
    parser.add_argument('--grid-length', '-L', type=float, default=500,
                        help="Length of coordinate grid: [-L/2,L/2]")
    parser.add_argument('--grid-pts', '-N', type=int, default=2000,
                        help="Number of gridpoints")
    parser.add_argument('--save-interval', type=int, default=1,
                        help="How many time-steps to skip between saves")
    parser.add_argument('--expm-tolerance', type=float, default=1e-5,
                        help="Relative error for the operation of the matrix exponential")
    parser.add_argument('--krylov-subspace-maxdim', type=int, default=50,
                        help="Maximum dimension of the Krylov Subspace to use \
                        (warns if propagator not converged within this)")
    parser.add_argument('--pulse-amplitude', type=float, default=0.08,
                        help="Amplitude of the laser electric field in atomic units")
    parser.add_argument('--pulse-frequency', type=float, default=0.0569,
                        help="Carrier frequency of the laser electric field")
    parser.add_argument('--pulse-duration', type=float, default=1000,
                        help="Duration of the laser electric field (sin^2\
                        envelope)")
    parser.add_argument('--savefile', type=str, default="qdyn.h5",
                        help="Path where to save the simulation data")
    parser.add_argument('--cap-width', type=float, default=10,
                        help="Width of the complex absorbing potential")
    parser.add_argument('--cap-height', type=float, default=0.5,
                        help="Height of the complex absorbing potential")
    # `type=str.lower` keeps the old case-insensitive behaviour; `choices`
    # rejects unsupported formats up front instead of deep inside scipy.
    parser.add_argument('--sparse-matrix-format', type=str.lower, default='dia',
                        choices=_SPARSE_FORMATS,
                        help="Type of the sparse matrices to use ( 'bsr',\
                        'coo','csc', 'csr', 'dia')")
    return parser


def parse_args(argv=None):
    """Parse and validate command-line arguments.

    Parameters
    ----------
    argv : list of str, optional
        Arguments without the program name. ``None`` means ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        The parsed options. ``sparse_matrix_format`` is already lower-cased.

    Raises
    ------
    SystemExit
        Via ``parser.error`` (exit code 2) on invalid option values.
    """
    parser = build_arg_parser()
    args = parser.parse_args(argv)
    # A zero/negative interval would divide by zero or give negative shapes.
    if args.save_interval < 1:
        parser.error("--save-interval must be a positive integer")
    if args.delta_t <= 0:
        parser.error("--delta-t must be positive")
    # A negative max-t would give an empty time grid and a zero-width dataset.
    if args.max_t < 0:
        parser.error("--max-t must be non-negative")
    return args


def main(argv=None):
    """Run the laser-driven TDSE simulation and write the results to HDF5.

    Column k of ``/wavefunction`` holds psi at ``/savetimes[k]``, and
    ``/final_wavefunction`` holds psi at the last time point of the grid
    ``arange(0, max_t + dt/2, dt)``.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments (see :func:`parse_args`).
    """
    args = parse_args(argv)

    delta_t = args.delta_t
    max_t = args.max_t
    grid_pts = args.grid_pts
    save_every_nth_timestep = args.save_interval
    pulse_params = (args.pulse_amplitude, args.pulse_frequency,
                    args.pulse_duration)
    sparse_matrix_format = args.sparse_matrix_format

    # Set up the coordinate grid
    coordinate_grid = np.linspace(-args.grid_length / 2, args.grid_length / 2,
                                  grid_pts)

    # Get the initial state and time-independent part of the Hamiltonian
    psi, H0 = get_initial_state(coordinate_grid, potential,
                                sparse_matrix_format)

    # Add the complex absorbing potential
    cap = complex_absorbing_potential(coordinate_grid, args.cap_width,
                                      args.cap_height)

    # Note: scipy changes the sparse matrix type as it wishes
    Hti = H0 + sp.diags(cap, 0, format=sparse_matrix_format)

    # Time grid; the +dt/2 makes max_t itself part of the grid.
    times = np.arange(0, max_t + delta_t / 2.0, delta_t)
    savetimes = times[::save_every_nth_timestep]
    last_index = len(times) - 1

    # The context manager guarantees the file is flushed and closed even if
    # the propagation raises.
    with h5py.File(args.savefile, 'w') as savefile:
        wf_dataset = savefile.create_dataset(
            "wavefunction", shape=(grid_pts, len(savetimes)),
            dtype='complex', chunks=(grid_pts, 1),
            compression="gzip", compression_opts=3)
        save_sparse_matrix(savefile, "tise_hamiltonian", H0)

        # Setup a progressbar
        bar = progressbar.ProgressBar(redirect_stdout=True,
                                      widgets=[
                                          progressbar.Bar(), '  ',
                                          progressbar.DynamicMessage('wfnorm'),
                                          '   ',
                                          '[', progressbar.AdaptiveETA(), ']'])

        # Main propagation loop: at the top of iteration k, psi is ψ(x, times[k]).
        for iteration, t in enumerate(bar(times)):
            # Save *before* stepping so column k matches savetimes[k]
            # (this also records the initial state).
            if iteration % save_every_nth_timestep == 0:
                norm = simps(np.abs(psi)**2, x=coordinate_grid)
                bar.update(wfnorm=norm)
                wf_dataset[:, iteration // save_every_nth_timestep] = psi

            # The final grid point is max_t; stepping past it would overshoot.
            if iteration < last_index:
                # Calculate H(t + Δt/2) (midpoint rule for the step)
                Hmid = Hti + sp.diags(
                    coordinate_grid * laser(t + delta_t / 2.0, *pulse_params),
                    0, format=sparse_matrix_format)

                # ψ(x, t + Δt) = U( t + Δt, t) ψ(x, t)
                psi = evolve_timestep(psi, Hmid, delta_t,
                                      args.expm_tolerance,
                                      args.krylov_subspace_maxdim)

        # Save ranges for plotting
        savefile["coordinate_grid"] = coordinate_grid
        savefile["savetimes"] = savetimes
        savefile["final_wavefunction"] = psi
        # Save laser: column 0 = time, column 1 = field
        savefile["laser"] = np.vstack((times, laser(times, *pulse_params))).T

        # Write a human-readable description of the savefile
        savefile["/"].attrs["description"] = _SAVEFILE_DESCRIPTION


if __name__ == '__main__':
    main()
