#!/usr/bin/env python3
""" H5plot: the spiritual successor to ParmDBplot for quickly reviewing gain solutions generated by NDPPP.

Solutions are read from an H5Parm with losoto and shown in PyQt5 windows with embedded matplotlib figures.
"""
__version__ = 'v2.6.0'
import logging
import signal
import sys

from PyQt5.QtWidgets import QApplication, QComboBox, QDialog, QFormLayout, QGridLayout, QHBoxLayout, QLabel, \
    QListWidget, QPushButton, QVBoxLayout, QWidget
from PyQt5 import QtCore

from losoto.lib_operations import reorderAxes
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure

import losoto.h5parm as lh5
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
# Select the Qt5 backend (the figures are embedded in PyQt5 windows via FigureCanvasQTAgg) and
# turn on matplotlib's interactive mode.
# NOTE(review): matplotlib.pyplot is already imported above, so this switches the backend after import.
matplotlib.use('Qt5Agg')
plt.ion()

# Restore the default SIGINT handler so Ctrl+C terminates the process; presumably this is needed because
# a running Qt event loop would otherwise keep KeyboardInterrupt from being raised promptly.
signal.signal(signal.SIGINT, signal.SIG_DFL)


def _split_polarizations(y_axis, y_axis_weight, pol_names):
    """ Split (n, npol) value and weight arrays into per-polarization 1D arrays.

    Args:
        y_axis (ndarray): values with the polarization along the second axis.
        y_axis_weight (ndarray): weights shaped like `y_axis`.
        pol_names (sequence): polarization labels, indexed like the second axis of `y_axis`.
    Returns:
        (list, list, list): per-polarization values, weights and labels.
    """
    npol = y_axis.shape[1]
    values = [y_axis[:, i] for i in range(npol)]
    weights = [y_axis_weight[:, i] for i in range(npol)]
    labels = [pol_names[i] for i in range(npol)]
    return values, weights, labels


def load_axes(vals, st, axis_type, antenna, refantenna, timeslot=0, freqslot=0, direction=0, weights=None):
    """ Load an abscissa and ordinate from the H5Parm.

    Args:
        vals (tuple): soltab values as returned by Soltab.getValues(): ``(values, axes)``, where ``axes``
            maps axis names (e.g. 'time', 'freq', 'pol') to their coordinate values.
        st (soltab): the soltab the values belong to; its name, type and axis names select the slicing.
        axis_type (str): `time` or `freq`.
        antenna (int): index of the antenna to select.
        refantenna (int): index of the reference antenna.
        timeslot (int): timeslot to load (used when plotting against frequency).
        freqslot (int): frequency slot to load (used when plotting against time).
        direction (int): index of the direction to load.
        weights (ndarray or None): weights for soltab values. Unit weights are used if None.
    Returns:
        xaxis (ndarray): an abscissa to plot.
        yaxis (ndarray or list): an ordinate to plot. For the polarization-resolved branches this is a
            list holding one array per polarization.
        yaxis_weight (ndarray or list): weights for the ordinate, structured like `yaxis`.
        plabels (list): a list of labels for each plot (e.g. different polarizations).
        isphase (bool): boolean indicating whether or not the quantity is a phase.
    Raises:
        ValueError: if `axis_type` is neither `time` nor `freq`.
    """
    if axis_type not in ('time', 'freq'):
        raise ValueError("axis_type must be 'time' or 'freq', got {!r}.".format(axis_type))
    wrapphase = True
    # Values are expected to have shape (timestamps, frequencies, antennas, polarizations, directions),
    # with the axes a soltab does not have left out.
    axes = st.getAxesNames()
    st_type = st.getType()
    x_axis = vals[1][axis_type]
    values = vals[0]
    plabels = []
    isphase = False
    if weights is None:
        weights = np.ones(values.shape)

    if axis_type == 'time':
        if ('rotationmeasure' in st.name) or ('faraday' in st.name) or ('tec' in st.name and 'freq' not in axes and 'dir' not in axes):
            Y_AXIS = values[:, antenna] - values[:, refantenna]
            Y_AXIS_WEIGHT = weights[:, antenna]
        elif ('pol' in axes) and ('dir' in axes):
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[:, freqslot, antenna, :, direction] - values[:, freqslot, refantenna, :, direction]
                y_axis_weight = weights[:, freqslot, antenna, :, direction]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            elif st_type in ('clock', 'rotationmeasure') or (st_type == 'tec' and 'freq' not in axes):
                y_axis = values[:, antenna, direction] - values[:, refantenna, direction]
                y_axis_weight = weights[:, antenna, direction]
            else:
                y_axis = values[:, freqslot, antenna, :, direction]
                y_axis_weight = weights[:, freqslot, antenna, :, direction]
            Y_AXIS, Y_AXIS_WEIGHT, plabels = _split_polarizations(y_axis, y_axis_weight, vals[1]['pol'])
        elif 'pol' in axes:
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[:, freqslot, antenna, :] - values[:, freqslot, refantenna, :]
                y_axis_weight = weights[:, freqslot, antenna, :]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            elif st_type in ('clock', 'rotationmeasure', 'tec'):
                y_axis = values[:, antenna] - values[:, refantenna]
                y_axis_weight = weights[:, antenna]
            else:
                y_axis = values[:, freqslot, antenna, :]
                y_axis_weight = weights[:, freqslot, antenna, :]
            Y_AXIS, Y_AXIS_WEIGHT, plabels = _split_polarizations(y_axis, y_axis_weight, vals[1]['pol'])
        elif 'dir' in axes:
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[:, freqslot, antenna, direction] - values[:, freqslot, refantenna, direction]
                y_axis_weight = weights[:, freqslot, antenna, direction]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            elif st_type in ('clock', 'rotationmeasure') or (st_type == 'tec' and 'freq' not in axes):
                y_axis = values[:, antenna, direction] - values[:, refantenna, direction]
                y_axis_weight = weights[:, antenna, direction]
            else:
                y_axis = values[:, freqslot, antenna, direction]
                y_axis_weight = weights[:, freqslot, antenna, direction]
            Y_AXIS = y_axis
            Y_AXIS_WEIGHT = y_axis_weight
        else:
            # Neither a polarization nor a direction axis.
            if st_type in ('clock', 'rotationmeasure') or (st_type in ('tec', 'phase') and 'freq' not in axes):
                Y_AXIS = values[:, antenna] - values[:, refantenna]
                Y_AXIS_WEIGHT = weights[:, antenna]
            elif st_type in ('tec', 'phase') and 'freq' in axes:
                # NOTE(review): this always uses the first frequency slot (not `freqslot`) and does not mark
                # phases as such or wrap them; kept as-is, confirm this is intended.
                Y_AXIS = values[:, 0, antenna] - values[:, 0, refantenna]
                Y_AXIS_WEIGHT = weights[:, 0, antenna]
            else:
                # Assume normal time, freq, antenna ordering.
                Y_AXIS = values[:, freqslot, antenna]
                Y_AXIS_WEIGHT = weights[:, freqslot, antenna]
    else:
        # axis_type == 'freq'
        if any(key in st.name for key in ('rotationmeasure', 'clock', 'faraday', 'tec')):
            logging.warning('Soltab %s does not support a frequency axis! Switch to time instead.', st.name)
        if ('pol' in axes) and ('dir' in axes):
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[timeslot, :, antenna, :, direction] - values[timeslot, :, refantenna, :, direction]
                y_axis_weight = weights[timeslot, :, antenna, :, direction]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            else:
                y_axis = values[timeslot, :, antenna, :, direction]
                y_axis_weight = weights[timeslot, :, antenna, :, direction]
            Y_AXIS, Y_AXIS_WEIGHT, plabels = _split_polarizations(y_axis, y_axis_weight, vals[1]['pol'])
        elif 'pol' in axes:
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[timeslot, :, antenna, :] - values[timeslot, :, refantenna, :]
                y_axis_weight = weights[timeslot, :, antenna, :]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            else:
                y_axis = values[timeslot, :, antenna, :]
                y_axis_weight = weights[timeslot, :, antenna, :]
            Y_AXIS, Y_AXIS_WEIGHT, plabels = _split_polarizations(y_axis, y_axis_weight, vals[1]['pol'])
        elif 'dir' in axes:
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[timeslot, :, antenna, direction] - values[timeslot, :, refantenna, direction]
                y_axis_weight = weights[timeslot, :, antenna, direction]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            else:
                y_axis = values[timeslot, :, antenna, direction]
                y_axis_weight = weights[timeslot, :, antenna, direction]
            Y_AXIS = y_axis
            Y_AXIS_WEIGHT = y_axis_weight
        else:
            # Neither a polarization nor a direction axis.
            Y_AXIS = values[timeslot, :, antenna]
            Y_AXIS_WEIGHT = weights[timeslot, :, antenna]
    if not plabels:
        plabels = ['', '', '', '']
    return x_axis, Y_AXIS, Y_AXIS_WEIGHT, plabels, isphase


def load_axes_2d(vals, weights, st, antenna, refantenna, pol=0, direction=0):
    """ Load a 2D (time, frequency) slice from the H5Parm.

    Phase-like quantities (phase and rotation soltabs) are referenced to `refantenna` and wrapped.

    Args:
        vals (tuple): soltab values as obtained from Soltab.getValues(): ``(values, axes)``, where ``axes``
            maps axis names to their coordinate values.
        weights (ndarray or None): weights corresponding to ``vals[0]``. Unit weights are used if None.
        st (soltab): soltab object.
        antenna (int): index of the antenna to select.
        refantenna (int): index of the reference antenna.
        pol (int): index of the polarization to load (ignored if the soltab has no polarization axis).
        direction (int): index of the direction to load (ignored if the soltab has no direction axis).
    Returns:
        x_axis (ndarray): time axis.
        y_axis (ndarray): frequency axis.
        plotvals (ndarray): 2D array of soltab values.
        plotvals_weight (ndarray): 2D array of weights corresponding to `plotvals`.
        isphase (bool): boolean indicating whether or not the quantity is a phase.
    """
    wrapphase = True
    # Values are expected to have shape (timestamps, frequencies, antennas, polarizations, directions),
    # with the axes a soltab does not have left out.
    axes = st.getAxesNames()
    st_type = st.getType()
    x_axis = vals[1]['time']
    y_axis = vals[1]['freq']
    values = np.asarray(vals[0])
    if weights is None:
        weights = np.ones(values.shape)
    try:
        pols = list(vals[1]['pol'])
    except (KeyError, TypeError):
        pols = []
        logging.debug('No polarization axis present.')
    isphase = st_type in ('phase', 'rotation')

    def _slice(array, ant, pol_index):
        # Select one antenna, plus the polarization and direction when those axes exist,
        # keeping the time and frequency axes.
        index = [slice(None), slice(None), ant]
        if 'pol' in axes:
            index.append(pol_index)
        if 'dir' in axes:
            index.append(direction)
        return array[tuple(index)]

    if isphase:
        if 'polalign' in st.name:
            # Special case. We must plot XX-YY and XX is 0 so not informative anyway.
            # This by definition has polarizations.
            xx, yy = pols.index('XX'), pols.index('YY')
            plotvals = (_slice(values, antenna, xx) - _slice(values, antenna, yy)) - (_slice(values, refantenna, xx) - _slice(values, refantenna, yy))
        else:
            # Plot phase-like quantities w.r.t. the reference antenna.
            plotvals = _slice(values, antenna, pol) - _slice(values, refantenna, pol)
        if wrapphase:
            plotvals = wrap_phase(plotvals)
    elif ('pol' in axes) or ('dir' in axes):
        plotvals = _slice(values, antenna, pol)
    else:
        # NOTE(review): unlike the other non-phase branches, values without pol/dir axes are
        # referenced to the reference antenna; kept as-is, confirm this is intended.
        plotvals = _slice(values, antenna, pol) - _slice(values, refantenna, pol)
    plotvals_weight = _slice(weights, antenna, pol)

    return x_axis, y_axis, plotvals, plotvals_weight, isphase


class GraphWindow(QDialog):
    """ A window displaying the plotted quantity. Allows the user to cycle through time or frequency.
    """
    def __init__(self, values, weights, frametitle, antindex, refantindex, axis, st, timeslot=0, freqslot=0, direction=0, times=None, freqs=None, parent=None):
        """ Initialize a new GraphWindow instance.

        Args:
            values (tuple): soltab values as passed to load_axes (``(values, axes)``).
            weights (ndarray or None): weights corresponding to the soltab values.
            frametitle (str): title the frame will have.
            antindex (int): the index of the selected antenna.
            refantindex (int): the index of the reference antenna.
            axis (str): the type of axis being plotted (time or freq).
            st (soltab): the soltab being plotted.
            timeslot (int): index along the time axis to start with.
            freqslot (int): index along the frequency axis to start with.
            direction (str): name of the direction to plot.
            times (ndarray or None): time axis values; the caller's array is not modified.
            freqs (ndarray or None): frequency axis values in Hz.
            parent (QDialog): parent window instance; must provide `stations` and `directions`.
        Returns:
            None
        """
        super(GraphWindow, self).__init__()
        self.setWindowFlags(QtCore.Qt.WindowSystemMenuHint | QtCore.Qt.WindowMinMaxButtonsHint | QtCore.Qt.WindowCloseButtonHint)
        # Set up for logging output.
        self.LOGGER = logging.getLogger('GraphWindow')
        self.LOGGER.setLevel(logging.DEBUG)
        self.LOGFILEH = self._get_file_handler()

        self.frametitle = frametitle
        self.axis = axis
        self.timeslot = timeslot
        self.freqslot = freqslot
        self.direction = direction
        self.values = values
        self.weights = weights
        self.antindex = antindex
        self.refantindex = refantindex
        self.st = st
        self.parent = parent
        self.frequencies = freqs
        if times is None or len(times) == 0:
            self.times = times
        else:
            # Make times relative to the first timestamp on a copy, leaving the caller's array untouched.
            times = np.asarray(times)
            self.times = times - times[0]

        self.setWindowTitle(frametitle)

        self.button_next = QPushButton('Forward')
        self.button_next.clicked.connect(self._forward_button_event)
        self.button_prev = QPushButton('Back')
        self.button_prev.clicked.connect(self._backward_button_event)
        # Without a usable slot axis the label stays empty and both buttons stay disabled.
        self.select_label = QLabel('')
        self.button_prev.setEnabled(False)
        self.button_next.setEnabled(False)
        if 'time' in axis.lower() and self.frequencies is not None:
            self.select_label.setText('Freq slot {:.2f} MHz'.format(self.frequencies[freqslot] / 1e6))
            self._update_slot_buttons(freqslot, len(self.frequencies))
        elif 'freq' in axis.lower() and self.times is not None:
            self.select_label.setText('Time: ' + self.format_time(self.times[timeslot]))
            self._update_slot_buttons(timeslot, len(self.times))
        self.select_label.setAlignment(QtCore.Qt.AlignCenter)

        self.btn_antiter_next = QPushButton('Next antenna')
        self.btn_antiter_next.clicked.connect(self._antiter_next_button_event)
        self.btn_antiter_prev = QPushButton('Previous antenna')
        self.btn_antiter_prev.clicked.connect(self._antiter_prev_button_event)

        antiter_widget = QWidget()
        antiter_layout = QHBoxLayout(antiter_widget)
        antiter_layout.addWidget(self.btn_antiter_prev)
        antiter_layout.addWidget(self.btn_antiter_next)

        self.buttons = QGridLayout()
        self.buttons.addWidget(self.button_prev, 0, 0)
        self.buttons.addWidget(self.select_label, 0, 1)
        self.buttons.addWidget(self.button_next, 0, 2)

        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.toolbar = NavigationToolbar(self.canvas, self)

        self.layout = QVBoxLayout()
        self.layout.addWidget(self.toolbar)
        self.layout.addWidget(self.canvas)
        self.layout.addItem(self.buttons)
        self.layout.addWidget(antiter_widget)
        self.setLayout(self.layout)

    def _get_file_handler(self):
        """ Return the logger's file handler, creating and attaching one for h5plot.log if needed.

        Loggers are shared by name, so reusing an attached handler avoids re-opening the log file
        and duplicating every message for each additional window.
        """
        for handler in self.LOGGER.handlers:
            if isinstance(handler, logging.FileHandler):
                return handler
        handler = logging.FileHandler('h5plot.log')
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
        self.LOGGER.addHandler(handler)
        return handler

    def format_time(self, seconds):
        """ Formats the time to be displayed in the plotting windows.

        A string is formatted, displaying the time in seconds or (fractional) minutes or hours.

        Args:
            seconds (float): the time in seconds.
        Returns:
            formatted time (str): formatted time string.
        """
        if seconds < 60:
            return '{:.3f} sec'.format(seconds)
        elif 60 <= seconds < 3600:
            return '{:.3f} min'.format(seconds / 60)
        elif seconds >= 3600:
            return '{:.3f} hr'.format(seconds / 3600)
        else:
            return '{:.3f}'.format(seconds)

    def _update_slot_buttons(self, slot, nslots):
        """ Enable Back/Forward only when there is a previous/next slot to move to. """
        self.button_prev.setEnabled(slot > 0)
        self.button_next.setEnabled(slot < nslots - 1)

    def _load_current_axes(self):
        """ Load the axes for the current antenna, time/frequency slot, direction and weights. """
        diridx = self.parent.directions.index(self.direction)
        # load_axes only uses timeslot for frequency plots and freqslot for time plots, so both can be passed.
        return load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, timeslot=self.timeslot, freqslot=self.freqslot, direction=diridx, weights=self.weights)

    def _show_antenna(self, antindex):
        """ Switch to antenna `antindex`, replot and update the window title. """
        self.antindex = antindex
        x, y, yw, l, p = self._load_current_axes()
        self.frametitle = self.parent.stations[self.antindex]
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)
        self.setWindowTitle(self.frametitle)

    def _antiter_next_button_event(self):
        """ Show the next antenna, wrapping around to the first one. """
        self._show_antenna((self.antindex + 1) % len(self.parent.stations))

    def _antiter_prev_button_event(self):
        """ Show the previous antenna, wrapping around to the last one. """
        self._show_antenna((self.antindex - 1) % len(self.parent.stations))

    def _step_slot(self, step):
        """ Move the frequency (time plots) or time (frequency plots) slot by `step` and replot.

        Moves outside the available slots are ignored.
        """
        if 'time' in self.axis.lower():
            newslot = self.freqslot + step
            if not 0 <= newslot < len(self.frequencies):
                return
            self.freqslot = newslot
            self.select_label.setText('Frequency: {:.3f} MHz'.format(self.frequencies[newslot] / 1e6))
            nslots = len(self.frequencies)
        elif 'freq' in self.axis.lower():
            newslot = self.timeslot + step
            if not 0 <= newslot < len(self.times):
                return
            self.timeslot = newslot
            self.select_label.setText('Time: ' + self.format_time(self.times[newslot]))
            nslots = len(self.times)
        else:
            return
        self._update_slot_buttons(newslot, nslots)
        x, y, yw, l, p = self._load_current_axes()
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)

    def _forward_button_event(self):
        """ An event triggered by pressing the "Forward" button of a GraphWindow.

        When pressed, the next frequency slot (time plots) or time slot (frequency plots) is shown.
        """
        self._step_slot(1)

    def _backward_button_event(self):
        """ An event triggered by pressing the "Back" button of a GraphWindow.

        When pressed, the previous frequency slot (time plots) or time slot (frequency plots) is shown.
        """
        self._step_slot(-1)

    def _mark_flagged(self, xaxis, values, weights):
        """ Mark zero-weight (flagged) samples with red vertical lines and red markers. """
        flagged = weights == 0
        if flagged.any():
            for x in xaxis[flagged]:
                self.ax.axvline(x, color='r')
            self.ax.plot(xaxis[flagged], values[flagged], 'h', color='r')

    def plot(self, xaxis, yaxis, yaxis_weight, frametitle='', limits=(None, None), ax_labels=('', ''), plot_labels=(), multidim=False, isphase=False):
        """ Draw one or more 1D curves, marking zero-weight (flagged) samples in red.

        Args:
            xaxis (array-like): abscissa; shifted to start at 0 when the x label contains 'time'.
            yaxis (array-like): ordinate, either 1D or 2D with one row per curve.
            yaxis_weight (array-like): weights structured like `yaxis`; NaN values of `yaxis` count as flagged.
            frametitle (str): title prefix; the direction name is appended.
            limits (sequence): (xlim, ylim) passed to Axes.set.
            ax_labels (sequence): (x label, y label).
            plot_labels (sequence): legend labels for the rows of a 2D `yaxis`.
            multidim (bool): unused.
            isphase (bool): if True, the y range is set to [-pi, pi].
        """
        self.xlabel = ax_labels[0]
        self.ylabel = ax_labels[1]
        self.xlabelp = plot_labels[0] if len(plot_labels) > 0 else ''
        self.ylabelp = plot_labels[1] if len(plot_labels) > 1 else ''
        self.fig.clf()
        self.ax = self.fig.add_subplot(111)
        xaxis = np.asarray(xaxis)
        if 'time' in ax_labels[0]:
            # Start counting from t=0
            xaxis = xaxis - xaxis[0]
        self.ax.set_title(frametitle + ' - {:s}'.format(self.direction))
        yaxis = np.asarray(yaxis)
        # Work on a float copy so flagging below never modifies the caller's weights.
        yaxis_weight = np.array(yaxis_weight, dtype=float)
        # Set weights to 0 for NaN solutions.
        yaxis_weight[np.isnan(yaxis)] = 0
        if yaxis.ndim > 1 and len(plot_labels) != 0:
            for i, (v, vw) in enumerate(zip(yaxis, yaxis_weight)):
                color = 'C' + str(i)
                self.ax.plot(xaxis, v, '--', alpha=0.25, color=color)
                self.ax.plot(xaxis, v, 'h', label=plot_labels[i], color=color)
                # Handle flagged data. Weights are polarization independent.
                self._mark_flagged(xaxis, v, vw)
            self.ax.legend()
        elif yaxis.ndim > 1:
            for i, (v, vw) in enumerate(zip(yaxis, yaxis_weight)):
                color = 'C' + str(i)
                v_m = np.ma.masked_where(vw == 0, v)
                self.ax.plot(xaxis, v_m, '--', alpha=0.25, color=color)
                self.ax.plot(xaxis, v_m, 'h', color=color)
        else:
            self.ax.plot(xaxis, yaxis, '--', alpha=0.25, color='C0')
            self.ax.plot(xaxis, yaxis, 'h', color='C0')
            self._mark_flagged(xaxis, yaxis, yaxis_weight)
        if isphase:
            self.ax.set_ylim(-np.pi, np.pi)
        self.ax.set(xlabel=ax_labels[0], ylabel=ax_labels[1], xlim=limits[0], ylim=limits[1])
        self.canvas.draw()


class GraphWindow2D(QDialog):
    """ A window displaying the plotted 2D quantity. Allows the user to cycle through antenna.
    """
    def __init__(self, values, weights, frametitle, antindex, refantindex, axis, st, polslot=0, direction=0, times=None, freqs=None, pols=None, parent=None):
        """ Initialize a new GraphWindow2D instance.

        Args:
            values (tuple): soltab values as passed to load_axes_2d (``(values, axes)``).
            weights (ndarray or None): array of weights corresponding to the values.
            frametitle (str): title the frame will have.
            antindex (int): the index of the selected antenna.
            refantindex (int): the index of the reference antenna.
            axis (str): the plotted axis; stored but not otherwise used by this window.
            st (soltab): the soltab being plotted.
            polslot (int): index along the polarization axis to start with.
            direction (str): name of the direction to plot.
            times (ndarray or None): time axis values; the caller's array is not modified.
            freqs (ndarray or None): frequency axis values.
            pols (sequence or None): polarization names.
            parent (QDialog): parent window instance; must provide `stations` and `directions`.
        Returns:
            None
        """
        super(GraphWindow2D, self).__init__()
        self.setWindowFlags(QtCore.Qt.WindowSystemMenuHint | QtCore.Qt.WindowMinMaxButtonsHint | QtCore.Qt.WindowCloseButtonHint)
        # Set up for logging output.
        self.LOGGER = logging.getLogger('GraphWindow')
        self.LOGGER.setLevel(logging.DEBUG)
        self.LOGFILEH = self._get_file_handler()

        self.frametitle = frametitle
        self.axis = axis
        self.polslot = polslot
        self.direction = direction
        self.values = values
        self.weights = weights
        self.antindex = antindex
        self.refantindex = refantindex
        self.st = st
        self.parent = parent
        self.polarizations = pols
        self.frequencies = freqs
        if times is None or len(times) == 0:
            self.times = times
        else:
            # Make times relative to the first timestamp on a copy, leaving the caller's array untouched.
            times = np.asarray(times)
            self.times = times - times[0]

        self.setWindowTitle(frametitle)

        self.button_next = QPushButton('Forward')
        self.button_next.clicked.connect(self._forward_button_event)
        self.button_prev = QPushButton('Back')
        self.button_prev.clicked.connect(self._backward_button_event)
        # Without polarizations (or for rotation/TEC soltabs) the label stays empty and the buttons disabled.
        self.select_label = QLabel('')
        self.button_prev.setEnabled(False)
        self.button_next.setEnabled(False)
        if self.polarizations is not None and self.st.getType() not in ['rotation', 'tec']:
            self.select_label.setText('Corr.: {:s}'.format(self.polarizations[polslot]))
            self._update_pol_buttons()
        self.select_label.setAlignment(QtCore.Qt.AlignCenter)

        self.btn_antiter_next = QPushButton('Next antenna')
        self.btn_antiter_next.clicked.connect(self._antiter_next_button_event)
        self.btn_antiter_prev = QPushButton('Previous antenna')
        self.btn_antiter_prev.clicked.connect(self._antiter_prev_button_event)

        antiter_widget = QWidget()
        antiter_layout = QHBoxLayout(antiter_widget)
        antiter_layout.addWidget(self.btn_antiter_prev)
        antiter_layout.addWidget(self.btn_antiter_next)

        self.buttons = QGridLayout()
        self.buttons.addWidget(self.button_prev, 0, 0)
        self.buttons.addWidget(self.select_label, 0, 1)
        self.buttons.addWidget(self.button_next, 0, 2)

        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.toolbar = NavigationToolbar(self.canvas, self)

        self.layout = QVBoxLayout()
        self.layout.addWidget(self.toolbar)
        self.layout.addWidget(self.canvas)
        self.layout.addItem(self.buttons)
        self.layout.addWidget(antiter_widget)
        self.setLayout(self.layout)

    def _get_file_handler(self):
        """ Return the logger's file handler, creating and attaching one for h5plot.log if needed.

        Loggers are shared by name, so reusing an attached handler avoids re-opening the log file
        and duplicating every message for each additional window.
        """
        for handler in self.LOGGER.handlers:
            if isinstance(handler, logging.FileHandler):
                return handler
        handler = logging.FileHandler('h5plot.log')
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
        self.LOGGER.addHandler(handler)
        return handler

    def format_time(self, seconds):
        """ Formats the time to be displayed in the plotting windows.

        A string is formatted, displaying the time in seconds or (fractional) minutes or hours.

        Args:
            seconds (float): the time in seconds.
        Returns:
            formatted time (str): formatted time string.
        """
        if seconds < 60:
            return '{:.3f} sec'.format(seconds)
        elif 60 <= seconds < 3600:
            return '{:.3f} min'.format(seconds / 60)
        elif seconds >= 3600:
            return '{:.3f} hr'.format(seconds / 3600)
        else:
            return '{:.3f}'.format(seconds)

    def _update_pol_buttons(self):
        """ Enable Back/Forward only when there is a previous/next polarization to move to. """
        self.button_prev.setEnabled(self.polslot > 0)
        self.button_next.setEnabled(self.polslot < len(self.polarizations) - 1)

    def _replot(self):
        """ Reload the 2D slice for the current antenna, polarization and direction and redraw it. """
        diridx = self.parent.directions.index(self.direction)
        x, y, z, zw, p = load_axes_2d(self.values, self.weights, self.st, self.antindex, self.refantindex, self.polslot, direction=diridx)
        self.plot(x, y, z, ax_labels=('Time [s]', 'Freq. [MHz]'), isphase=p, frametitle=self.frametitle)

    def _show_antenna(self, antindex):
        """ Switch to antenna `antindex`, replot and update the window title. """
        self.antindex = antindex
        self.frametitle = self.parent.stations[self.antindex]
        self._replot()
        self.setWindowTitle(self.frametitle)

    def _antiter_next_button_event(self):
        """ Show the next antenna, wrapping around to the first one. """
        self._show_antenna((self.antindex + 1) % len(self.parent.stations))

    def _antiter_prev_button_event(self):
        """ Show the previous antenna, wrapping around to the last one. """
        self._show_antenna((self.antindex - 1) % len(self.parent.stations))

    def _step_pol(self, step):
        """ Move the polarization slot by `step` and replot; moves outside the available slots are ignored. """
        newslot = self.polslot + step
        if self.polarizations is None or not 0 <= newslot < len(self.polarizations):
            return
        self.polslot = newslot
        self._update_pol_buttons()
        self.select_label.setText('Corr.: {:s}'.format(self.polarizations[self.polslot]))
        self._replot()

    def _forward_button_event(self):
        """ An event triggered by pressing the "Forward" button: show the next polarization. """
        self._step_pol(1)

    def _backward_button_event(self):
        """ An event triggered by pressing the "Back" button: show the previous polarization. """
        self._step_pol(-1)

    def plot(self, xaxis, yaxis, zaxis, frametitle='', limits=(None, None), ax_labels=('', ''), multidim=False, isphase=False):
        """ Draw a (time, frequency) image of `zaxis` with a colorbar.

        Args:
            xaxis (array-like): time axis; shifted to start at 0.
            yaxis (array-like): frequency axis in Hz; displayed in MHz.
            zaxis (ndarray): 2D values indexed as (time, frequency).
            frametitle (str): title prefix.
            limits (sequence): (xlim, ylim) passed to Axes.set.
            ax_labels (sequence): (x label, y label).
            multidim (bool): unused.
            isphase (bool): if True, use the 'jet' colormap clipped to [-pi, pi]; otherwise 'viridis'.
        """
        import copy

        self.xlabel = ax_labels[0]
        self.ylabel = ax_labels[1]
        self.fig.clf()
        self.ax = self.fig.add_subplot(111)
        # Start counting from t=0
        xaxis = np.asarray(xaxis)
        xaxis = xaxis - xaxis[0]
        yaxis = np.asarray(yaxis) * 1e-6
        if self.polarizations is not None:
            self.ax.set_title(frametitle + ':' + self.st.name + ' - {:s} - {:s}'.format(self.polarizations[self.polslot], self.direction))
        else:
            self.ax.set_title(frametitle + ':' + self.st.name + ' - {:s}'.format(self.direction))
        # Copy the registered colormap: calling set_bad() on the shared instance would change it globally.
        current_cmap = copy.copy(plt.get_cmap('jet' if isphase else 'viridis'))
        current_cmap.set_bad(color='w')
        im = self.ax.imshow(np.asarray(zaxis).T, interpolation='none', extent=[xaxis[0], xaxis[-1], yaxis[0], yaxis[-1]], aspect='auto', cmap=current_cmap, origin='lower')
        if isphase:
            im.set_clim(-np.pi, np.pi)
        self.fig.colorbar(im)
        self.ax.set(xlabel=ax_labels[0], ylabel=ax_labels[1], xlim=limits[0], ylim=limits[1])
        self.canvas.draw()


class H5PlotGUI(QDialog):
    """The main GUI for H5Plot.

    From here the SolSets, SolTabs and antennas to plot are selected.
    """
    # Substrings of SolTab names whose solutions are only plotted against time:
    # they are rejected for waterfall plots and for the frequency axis.
    _TIME_ONLY_SOLTABS = ('rotationmeasure', 'RMextract', 'clock', 'faraday', 'tec')

    def __init__(self, h5file, logging_instance, parent=None):
        """ Initialize a new instance of the H5PlotGUI.

        Args:
            h5file (str): name of the H5Parm to open.
            logging_instance (logging.Logger): logger to write messages to.
            parent (QWidget or None): optional Qt parent widget.
        Returns:
            None
        """
        super(H5PlotGUI, self).__init__(parent)
        self.logger = logging_instance
        self.figures = []

        self.h5parm = lh5.h5parm(h5file)
        self.solset_labels = self.h5parm.getSolsetNames()
        self.solset = self.h5parm.getSolset(self.solset_labels[0])

        self.soltab_labels = self.solset.getSoltabNames()
        self.soltab = self.solset.getSoltab(self.soltab_labels[0])

        # Defaults in case no SolTab yields these axes; `plot` treats a None
        # frequency axis as a single frequency slot.
        self.frequencies = None
        self.times = None
        for label in self.soltab_labels:
            try:
                self.frequencies = self.solset.getSoltab(label).getAxisValues('freq')
                break
            except TypeError:
                pass
        for label in self.soltab_labels:
            try:
                self.times = self.solset.getSoltab(label).getAxisValues('time')
                break
            except TypeError:
                pass
        # Polarizations describe the SolTab that is currently selected (None if it has no pol axis).
        self._update_polarizations()
        self.stations = self.soltab.getValues()[1]['ant']
        try:
            self.directions = [s.decode('utf-8') for s in self.solset.getSou().keys()]
        except AttributeError:
            # Probably normal string.
            self.directions = [s for s in self.solset.getSou().keys()]
        self.direction = 0
        self.refant = self.stations[0]
        self.wrapphase = True

        # Cache the SolTab data reordered to the axis order the plotting code expects.
        rvals, rweights, raxes = reorder_soltab(self.soltab)
        self.stcache = SoltabCache(rvals, raxes, weights=rweights)

        self.move(300, 300)
        self.setWindowTitle('H5Plot')

        self.solset_label = QLabel('SolSet: ')
        self.solset_picker = QComboBox()
        for label in self.solset_labels:
            self.solset_picker.addItem(label)
        self.solset_picker.activated.connect(self._solset_picker_event)

        self.soltab_label_y = QLabel('Plot ')
        self.soltab_label_x = QLabel(' vs ')
        self.soltab_picker = QComboBox()
        for label in self.soltab_labels:
            self.soltab_picker.addItem(label)
        self.soltab_picker.activated.connect(self._soltab_picker_event)
        self.axis_picker = QComboBox()
        self.axis_picker.addItems(['time', 'freq', 'waterfall'])
        self.axis_picker.activated.connect(self._axis_picker_event)
        self.axis = 'time'

        self.refant_label = QLabel('Ref. Ant. ')
        self.refant_picker = QComboBox()
        self.refant_picker.addItems(self.stations)
        self.refant_picker.activated.connect(self._refant_picker_event)

        # A phase-wrapping toggle is not implemented yet; see `_phasewrap_event`.
        self.dir_label = QLabel('Dir.')
        self.dir_picker = QComboBox()
        self.dir_picker.addItems(self.directions)
        self.dir_picker.activated.connect(self._dir_picker_event)

        self.plot_button = QPushButton('Plot')
        self.plot_button.clicked.connect(self._plot_button_event)

        self.station_picker = QListWidget()
        self.station_picker.addItems(self.stations)
        self.station_picker.setCurrentRow(0)

        plot_layout = QGridLayout()
        plot_layout.addWidget(self.soltab_label_y, 0, 0)
        plot_layout.addWidget(self.soltab_picker, 0, 1)
        plot_layout.addWidget(self.soltab_label_x, 0, 2)
        plot_layout.addWidget(self.axis_picker, 0, 3)
        plot_layout.addWidget(self.refant_label, 1, 0)
        plot_layout.addWidget(self.refant_picker, 1, 1)
        plot_layout.addWidget(self.dir_label, 1, 2)
        plot_layout.addWidget(self.dir_picker, 1, 3)

        layout = QFormLayout(self)
        layout.addRow(self.solset_label, self.solset_picker)
        layout.addRow(plot_layout)
        layout.addRow(self.plot_button)
        layout.addRow(self.station_picker)

    def _update_polarizations(self):
        """Set `polarizations` to the pol axis values of the current SolTab, or None if it has none."""
        if 'pol' in self.soltab.getAxesNames():
            self.polarizations = self.soltab.getAxisValues('pol')
        else:
            self.polarizations = None

    def _refant_index(self):
        """Return the index of the reference antenna `refant` within `stations`.

        Raises:
            IndexError: if `refant` is not among the current stations.
        """
        # flatnonzero gives a 1D index array; int() on the 2D result of np.argwhere
        # is deprecated since NumPy 1.25.
        return int(np.flatnonzero(np.asarray(self.stations) == self.refant)[0])

    def _is_time_only_soltab(self):
        """Return True if the current SolTab's name marks a solution type that is only plotted against time."""
        return any(key in self.soltab.name for key in self._TIME_ONLY_SOLTABS)

    def _axis_picker_event(self):
        """Callback function for when the x-axis is changed.

        Sets the `axis` attribute to the selected axis
        """
        self.logger.debug('Axis changed to: ' + self.axis_picker.currentText())
        self.axis = self.axis_picker.currentText()

    def closeEvent(self, event):
        """ Close all figures and plot windows opened from this GUI, then accept the close event.
        """
        self.logger.info('Closing all open figures before exiting.')
        plt.close('all')
        for f in self.figures:
            f.close()
        event.accept()

    def _refant_picker_event(self):
        """ An event triggered when a new reference antenna is selected.

        Sets the `refant` attribute.
        """
        self.logger.debug('Reference antenna changed to: ' + self.refant_picker.currentText())
        self.refant = self.refant_picker.currentText()

    def _solset_picker_event(self):
        """Callback function for when the SolSet is changed.

        Sets the `solset` attribute, repopulates the SolTab picker and loads its first SolTab.
        """
        self.logger.debug('Solset changed to: ' + self.solset_picker.currentText())
        self.solset = self.h5parm.getSolset(self.solset_picker.currentText())
        self.soltab_labels = self.solset.getSoltabNames()
        self.soltab_picker.clear()
        for label in self.soltab_labels:
            self.soltab_picker.addItem(label)
        self._soltab_picker_event()

    def _soltab_picker_event(self):
        """Callback function for when the SolTab is changed.

        Sets the `soltab` attribute and refreshes the station list, reference antenna,
        time/frequency/polarization axes and the cached reordered data.
        """
        self.logger.debug('Soltab changed to: ' + self.soltab_picker.currentText())
        self.soltab = self.solset.getSoltab(self.soltab_picker.currentText())
        stations_old = self.stations
        self.stations = self.soltab.getValues()[1]['ant']
        if not np.array_equiv(stations_old, self.stations):
            self.logger.debug('Number of stations changed, updating list.')
            # The list of stations has changed, update the list.
            self.station_picker.clear()
            self.station_picker.addItems(self.stations)
            # clear() leaves no row selected (currentRow() == -1); select the first station again.
            self.station_picker.setCurrentRow(0)
            self.refant_picker.clear()
            self.refant_picker.addItems(self.stations)
            # Keep the reference antenna if it still exists, otherwise fall back to the
            # first station so `refant` never points at a station that is gone.
            if self.refant in self.stations:
                self.refant_picker.setCurrentIndex(self._refant_index())
            else:
                self.refant = self.stations[0]
                self.logger.debug('Reference antenna reset to: ' + self.refant)
        try:
            self.frequencies = self.soltab.getAxisValues('freq')
        except TypeError:
            # Soltab probably has no frequency axis.
            pass
        try:
            self.times = self.soltab.getAxisValues('time')
        except TypeError:
            pass
        self._update_polarizations()
        rvals, rweights, raxes = reorder_soltab(self.soltab)
        self.stcache.update(rvals, raxes, weights=rweights)

    def _dir_picker_event(self):
        """Callback function for when the direction is changed.

        Sets the `direction` attribute to the index of the selected direction.
        """
        self.logger.debug('Direction changed to: ' + self.dir_picker.currentText())
        self.direction = self.dir_picker.currentIndex()

    def _phasewrap_event(self):
        """ An event triggered upon switching phase wrapping on or off. (not yet implemented)

        Not connected to any widget: it needs a `phasewrap_box` checkbox, which is not created.
        """
        self.logger.debug('Phase wrapping changed to ' + str(self.phasewrap_box.isChecked()))
        self.wrapphase = self.phasewrap_box.isChecked()

    def _plot_button_event(self):
        """Callback function for when the plot button is pressed.

        Calls `plot` or `plot_waterfall`, depending on the selected axis.
        """
        self.logger.debug('Plotting button pressed.')
        if self.axis == 'freq' or self.axis == 'time':
            self.plot(labels=(self.axis, self.soltab.name))
        elif self.axis == 'waterfall':
            self.plot_waterfall(labels=('time', 'freq'))

    def plot_waterfall(self, labels=('x-axis', 'y-axis')):
        """ Show a two-dimensional waterfall plot of time vs. frequency.

        Args:
            labels (tuple): currently unused; the axes are labelled time [s] and frequency [MHz].
        """
        if self._is_time_only_soltab():
            self.logger.info('Rotation Measure, clock, faraday or TEC cannot be plotted in 2D!')
            return
        self.logger.info('Plotting ' + self.soltab.name + ' for ' + self.solset.name)
        antenna = self.station_picker.currentRow()
        if antenna < 0:
            self.logger.info('No station selected, not plotting!')
            return
        refantenna = self._refant_index()
        # Data loaded here is xaxis, yaxis, zaxis, zaxis weights, isphase
        self.logger.debug('Loading data')
        msg = load_axes_2d(self.stcache.values, self.stcache.weights, self.soltab, antenna=antenna, refantenna=refantenna, pol=0, direction=self.direction)
        if isinstance(msg, str):
            # An error message instead of data.
            self.logger.error('Error loading 2D data: ' + msg)
            return
        try:
            x, y, z, zw, p = msg
        except ValueError:
            self.logger.error('Error loading 2D data!')
            return
        if (len(x) == 1) or (len(y) == 1):
            self.logger.info('Either time or frequency has only 1 entry, not plotting!')
            return
        # SolTabs without a pol axis get a placeholder label.
        pols = self.polarizations if self.polarizations is not None else ['N/A']
        plot_window = GraphWindow2D(self.stcache.values, self.stcache.weights, self.stations[antenna], antenna, refantenna, self.axis, self.soltab, times=self.times, freqs=self.frequencies, pols=pols, parent=self, direction=self.directions[self.direction])
        self.figures.append(plot_window)
        # Mask zero-weight (flagged) samples so they are drawn as NaN.
        z[zw == 0] = np.nan
        plot_window.plot(x, y, z, ax_labels=('Time [s]', 'Freq. [MHz]'), isphase=p, frametitle=self.stations[antenna])
        plot_window.show()

    def plot(self, labels=('x-axis', 'y-axis'), limits=([None, None], [None, None])):
        """ Show a one-dimensional plot of the current SolTab against `self.axis`.

        Args:
            labels (tuple): (x label, y label); only the y label is used, the x label is `self.axis`.
            limits (tuple): currently unused; the plot is autoscaled.
        """
        self.logger.info('Plotting ' + self.soltab.name + ' vs ' + self.axis + ' for ' + self.solset.name)
        antenna = self.station_picker.currentRow()
        if self._is_time_only_soltab() and (self.axis == 'freq'):
            self.logger.info('Rotation Measure or clock does not support frequency axis! Switch to time instead.')
            return
        if antenna < 0:
            self.logger.info('No station selected, not plotting!')
            return
        refantenna = self._refant_index()
        msg = load_axes(self.stcache.values, self.soltab, self.axis, antenna=antenna, refantenna=refantenna, direction=self.direction, weights=self.stcache.weights)
        if isinstance(msg, str):
            # load_axes returned an error message: requested combination not supported.
            self.logger.debug('Not plotting: ' + msg)
            return
        try:
            x_axis, y_axis, y_axis_weight, plabels, isphase = msg
        except ValueError:
            # Requested combination not supported.
            return
        if 'freq' in self.soltab.getAxesNames():
            plot_window = GraphWindow(self.stcache.values, self.stcache.weights, self.stations[antenna], antenna, refantenna, self.axis, self.soltab, times=self.times, freqs=self.frequencies, parent=self, direction=self.directions[self.direction])
        else:
            # Probably TEC or another solution type with no frequency axis.
            plot_window = GraphWindow(self.stcache.values, self.stcache.weights, self.stations[antenna], antenna, refantenna, self.axis, self.soltab, times=self.times, parent=self, direction=self.directions[self.direction])
        self.figures.append(plot_window)
        plot_window.plot(x_axis, y_axis, y_axis_weight, self.stations[antenna], limits=[None, None], ax_labels=[self.axis, labels[1]], plot_labels=plabels, isphase=isphase)
        plot_window.show()
        self._configure_step_buttons(plot_window)

    def _configure_step_buttons(self, plot_window):
        """Enable the next/previous buttons of `plot_window` only when there is another slot to step to."""
        # TEC does not have a frequency axis, so disable the button as well.
        if 'tec' in self.soltab.name:
            self.logger.debug('TEC solutions detected, disabling buttons.')
            plot_window.button_next.setEnabled(False)
            plot_window.button_prev.setEnabled(False)
        axis = self.axis.lower()
        if axis == 'freq':
            if len(self.times) == 1:
                plot_window.button_next.setEnabled(False)
                self.logger.debug('Single time slot detected, disabling buttons.')
            elif len(self.times) > 1:
                self.logger.debug('Multiple time slots detected, enabling buttons.')
                plot_window.button_next.setEnabled(True)
        elif axis == 'time':
            # Only step through frequencies the current SolTab actually has, so a frequency
            # axis left over from another SolTab cannot re-enable the buttons (e.g. for TEC).
            if 'freq' in self.soltab.getAxesNames() and self.frequencies is not None:
                nfreq = len(self.frequencies)
            else:
                nfreq = 1
            if nfreq == 1:
                plot_window.button_next.setEnabled(False)
                self.logger.debug('Single frequency slot detected, disabling buttons.')
            elif nfreq > 1:
                self.logger.debug('Multiple frequency slots detected, enabling buttons.')
                plot_window.button_next.setEnabled(True)


class SoltabCache:
    """Plain holder for a temporarily reordered copy of SolTab data.

    Attributes:
        values: the cached soltab values, stored as given.
        axes: the cached axis description, stored as given.
        weights: the cached weights, or None.
    """
    def __init__(self, values, axes, weights=None):
        """ Create a cache pre-filled with the given data.

        Args:
            values (ndarray): values to cache.
            axes (ndarray): axes to store.
            weights (ndarray or None): weights to cache.
        Returns:
            None
        """
        self.update(values, axes, weights=weights)

    def update(self, nvalues, naxes, weights=None):
        """ Replace everything held by the cache.

        Args:
            nvalues (ndarray): new values to store in the cache.
            naxes (ndarray): new axes to store in the cache.
            weights (ndarray or None): new weights to store in the cache.
        Returns:
            None
        """
        self.values, self.axes, self.weights = nvalues, naxes, weights


# Global helper functions.
def reorder_soltab(st):
    """ Reorder a soltab in the order H5plot expects.

    The expected order in the plotter is time, frequency, antenna, polarization, direction.
    SolTabs whose name contains rotationmeasure, RMextract, clock or faraday, and TEC
    SolTabs without a `freq` axis, are reordered to time, antenna instead. Polarization
    and direction are appended only if present.

    The axis order of `st` itself is updated in place, so later `st.getValues()` /
    `st.getAxesNames()` calls reflect the new order.

    Args:
        st (SolTab): soltab instance to reorder the axes of.
    Returns:
        st_new (tuple): (reordered values, axis values from `st.getValues()[1]` after reordering).
        reordered_weights (ndarray): the weights reordered to the same axis order.
        order_new (list): the new order of the axis names.
    """
    # Look the logger up by name: the module-level LOGGER only exists when this
    # file is run as a script, so relying on it raised NameError on import.
    logging.getLogger('H5plot_logger').info('Reordering soltab ' + st.name)
    order_old = st.getAxesNames()
    if ('rotationmeasure' in st.name) or ('RMextract' in st.name) or ('clock' in st.name) or ('faraday' in st.name) or ('tec' in st.name and 'freq' not in order_old):
        # Treated as solution types without a frequency axis.
        order_new = ['time', 'ant']
    else:
        order_new = ['time', 'freq', 'ant']
    if 'pol' in order_old:
        order_new += ['pol']
    if 'dir' in order_old:
        order_new += ['dir']
    reordered = reorderAxes(st.getValues()[0], order_old, order_new)
    reordered_weights = reorderAxes(st.getValues(weight=True)[0], order_old, order_new)
    # Mirror the new order on the SolTab so subsequent accessors agree with the reordered data.
    st.axes = {k: st.axes[k] for k in order_new}
    st.axesNames = order_new
    st_new = (reordered, st.getValues()[1])
    return st_new, reordered_weights, order_new


def wrap_phase(phase):
    """ Wrap phases into the range [-pi, pi].

    Each phase is shifted by the multiple of 2*pi that brings it into a fixed,
    plottable interval, computed as ((phase + pi) mod 2*pi) - pi.

    Args:
        phase (ndarray or float): phase(s) in radians.
    Returns:
        wphase (ndarray or float): the wrapped phase(s).
    """
    full_turn = 2 * np.pi
    shifted = phase + np.pi
    return shifted % full_turn - np.pi


if __name__ == '__main__':
    print('H5Plot ' + __version__)
    print('Author: Frits Sweijen (frits.sweijen@gmail.com)')
    print('Released under GPLv3\n')
    import argparse
    parser = argparse.ArgumentParser(description='Interactive plotter to explore LOFAR H5parm solutions.')
    parser.add_argument('FILENAME')
    # Must be registered: otherwise parse_args() rejects --debug as an unrecognized argument.
    parser.add_argument('--debug', action='store_true', help='Write debug messages to h5plot.log.')
    args = parser.parse_args()

    H5FILE = lh5.h5parm(args.FILENAME, readonly=True)
    # Set up for logging output.
    LOGGER = logging.getLogger('H5plot_logger')
    LOGGER.setLevel(logging.DEBUG if args.debug else logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    LOGFILEH = logging.FileHandler('h5plot.log')
    LOGFILEH.setLevel(logging.DEBUG)
    LOGFILEH.setFormatter(formatter)
    LOGGER.addHandler(LOGFILEH)

    LOGGER.info('Successfully opened %s', args.FILENAME)

    # Ensure the H5Parm handle is closed even if listing or the GUI raises.
    try:
        SOLSETS = H5FILE.getSolsetNames()
        print('Found solset(s) ' + ', '.join(SOLSETS))
        for solset in SOLSETS:
            print('SolTabs in ' + solset + ':')
            ss = H5FILE.getSolset(solset)
            soltabs = ss.getSoltabNames()
            print('\t' + ', '.join(soltabs))

        # Initialize the GUI.
        APP = QApplication(sys.argv)
        GUI = H5PlotGUI(args.FILENAME, LOGGER)
        GUI.show()
        APP.exec_()
    finally:
        H5FILE.close()
        LOGGER.info('%s successfully closed.', args.FILENAME)
