#!/usr/bin/env python3

"""
Visualization of the electron density calculated with qdyn_laser from the
quantum_dynamics package.
"""

import argparse
parser = argparse.ArgumentParser(
    description="""Plots the time-evolution of the electron density as a density map.
By default, the plot is shown in an interactive matplotlib window, but saving
the figure to a file and working completely without a GUI are supported via
additional command line flags.""")


parser.add_argument("--datafile", "-f", type=str,
                    required=True, help="Path to the datafile.")

parser.add_argument("--nogui", action="store_true", default=False,
                    help="""Disables the interactive matplotlib window and only
saves the figure. Remember to also set the output filename.""")
parser.add_argument("-o", "--out", type=str, default="",
                    help="""Sets the output filename and saves the figure
(also) in a file""", dest="outputfile")

args = parser.parse_args()

# --nogui without an output file would produce nothing at all. Report the
# problem through argparse (usage message + exit status 2) instead of an
# ``assert``, which is stripped under ``python -O``. The attribute is named
# ``outputfile`` because of ``dest="outputfile"`` above.
if args.nogui and not args.outputfile:
    parser.error("--nogui requires an output filename (-o/--out)")

# Select the backend before matplotlib.pyplot is imported further below:
# the non-interactive 'Agg' backend for --nogui, 'TkAgg' otherwise.
import matplotlib
matplotlib.use('Agg' if args.nogui else 'TkAgg')

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import h5py


def get_figure(width=246, aspect=(np.sqrt(5.0) + 1) / 2.0):
    """
    Returns a matplotlib figure. The default values produce a figure whose
    width is 246 pt, i.e., the PRL column width, and whose aspect ratio
    (width/height) is the golden ratio.

    Parameters
    ----------
    width       int/float (Python or NumPy scalar), optional
                width of the figure in 'pt' units (1 pt ≈ 0.3528 mm)
                Defaults to 246 pt, i.e., the width of a single column in
                Physical Review Letters
    aspect      int/float (Python or NumPy scalar), optional
                Sets width/height to this value. Defaults to the golden
                ratio (1 + sqrt(5)) / 2.

    Returns
    -------
    fig         matplotlib.pyplot.Figure

    Raises
    ------
    TypeError   if width or aspect is not an int, float or NumPy
                integer/floating scalar
    """
    # Explicit raises instead of ``assert``: asserts are removed under
    # ``python -O`` and would then let invalid values reach matplotlib.
    numeric_types = (int, float, np.integer, np.floating)
    if not isinstance(width, numeric_types):
        raise TypeError(
            "type(width) is not int, float, np.integer or np.floating "
            "[was %r]" % type(width))
    if not isinstance(aspect, numeric_types):
        raise TypeError(
            "type(aspect) is not int, float, np.integer or np.floating "
            "[was %r]" % type(aspect))

    width_in = width / 72.  # 72 pt per inch; matplotlib expects inches
    height_in = width_in / aspect

    return plt.figure(figsize=(width_in, height_in))


# Load everything needed from the datafile up front; the ``with`` block
# guarantees the HDF5 file is closed even if a later step fails.
with h5py.File(args.datafile, 'r') as f:
    coordinate_grid = f["coordinate_grid"][:]
    times = f["savetimes"][:]
    wavefunction = f["wavefunction"][:]
    laser = f["laser"][:]

density = np.abs(wavefunction)**2

# Prepare the canvas and axes
fig = get_figure(512)
ax = fig.add_subplot(111)

# Plot the density map. ``extent`` spans the save times horizontally and
# the coordinate grid vertically. origin='lower' places row 0 at the bottom
# of the axes; 'bottom' is not an accepted value (only 'upper'/'lower').
extent = [times.min(), times.max(),
          coordinate_grid.min(), coordinate_grid.max()]
p = ax.imshow(density, origin='lower',
              aspect='auto', interpolation='gaussian',
              norm=LogNorm(vmin=1e-5, vmax=0.5 * density.max()),
              extent=extent,
              cmap='magma_r')

# Overlay the laser electric field, rescaled so that its largest absolute
# value reaches 90 % of the largest grid coordinate.
# assumes column 0 of "laser" is time and column 1 the field — TODO confirm
M = coordinate_grid.max() / np.abs(laser[:, 1]).max() * 0.9
ax.plot(laser[:, 0], laser[:, 1] * M, c='k', lw=0.5)
ax.set_xlim(laser[:, 0].min(), laser[:, 0].max())
ax.set_ylim(coordinate_grid.min(), coordinate_grid.max())

# Colorbar in a separate axes attached to the right of the density map
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = fig.colorbar(p, cax=cax)

# Axis labels
cbar.set_label(r'Electron density')
ax.set_xlabel(r'time (a.u.)')
ax.set_ylabel(r'$x$-coordinate (a.u.)')

fig.tight_layout(pad=1)

if args.outputfile:
    fig.savefig(args.outputfile, dpi=200)

if not args.nogui:
    plt.show()
