#!/usr/bin/env python3
""" The spiritual successor to ParmDBplot for quickly reviewing gain solutions generated by NDPPP.
"""
__version__ = 'v2.9.0'
import logging
import signal
import sys
import time

from PyQt5.QtWidgets import QApplication, QCheckBox, QComboBox, QDialog, QFormLayout, QGridLayout, QHBoxLayout, QLabel, \
    QLineEdit, QListWidget, QPushButton, QScrollBar, QVBoxLayout, QWidget
from PyQt5 import QtCore, QtWidgets

from losoto.lib_operations import reorderAxes
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
from matplotlib.backends.backend_qt5agg import NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure

import losoto.h5parm as lh5
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
# Select the Qt5 Agg backend used by the FigureCanvasQTAgg/NavigationToolbar2QT widgets
# imported above, and switch pyplot to interactive mode.
# NOTE(review): this runs after `matplotlib.pyplot` has already been imported; recent
# matplotlib versions switch backends here, but consider moving it above the pyplot import.
matplotlib.use('Qt5Agg')
plt.ion()

# Restore the OS default action for SIGINT so Ctrl+C terminates the program; Python's own
# handler only raises KeyboardInterrupt, which is not delivered while the Qt event loop runs.
signal.signal(signal.SIGINT, signal.SIG_DFL)


def _split_pol_axis(y_axis, y_axis_weight, pol_names):
    """ Split (slot, polarization) arrays into one 1D curve per polarization.

    Args:
        y_axis (ndarray): 2D array of values whose second axis runs over polarizations.
        y_axis_weight (ndarray): weights laid out like `y_axis`.
        pol_names (sequence): polarization names; entry i labels column i.
    Returns:
        curves (list): the value column for each polarization.
        curve_weights (list): the weight column for each polarization.
        labels (list): the polarization name for each curve.
    """
    columns = range(y_axis.shape[1])
    curves = [y_axis[:, i] for i in columns]
    curve_weights = [y_axis_weight[:, i] for i in columns]
    labels = [pol_names[i] for i in columns]
    return curves, curve_weights, labels


def load_axes(vals, st, axis_type, antenna, refantenna, timeslot=0, freqslot=0, direction=0, weights=None):
    """ Load an abscissa and ordinate from the H5Parm.

    Phase solutions are differenced against `refantenna` and wrapped. Clock, rotation measure
    and TEC solutions plotted against time are differenced against `refantenna` as well.
    Other quantities (e.g. amplitudes) are returned as stored.

    Args:
        vals (tuple): (values, axes) pair as obtained from Soltab.getValues(); `axes` maps axis
            names to their coordinate values.
        st (Soltab): soltab the values belong to; its axes names, type and name select the slicing.
        axis_type (str): `time` or `freq`.
        antenna (int): index of the antenna to select.
        refantenna (int): index of the reference antenna.
        timeslot (int): timeslot to load (used when plotting against frequency).
        freqslot (int): frequency slot to load (used when plotting against time).
        direction (int): index of the direction to load.
        weights (ndarray or None): weights for soltab values; all ones when None.
    Returns:
        xaxis (ndarray): an abscissa to plot.
        yaxis (ndarray or list): an ordinate to plot, or one ordinate per polarization.
        yaxis_weight (ndarray or list): weights for the ordinate(s).
        plabels (list): a label for each plot (e.g. different polarizations); four empty
            labels when there is no polarization axis.
        isphase (bool): boolean indicating whether or not the quantity is a phase.
    Raises:
        ValueError: if `axis_type` is neither `time` nor `freq`.
    """
    if axis_type not in ('time', 'freq'):
        raise ValueError("axis_type must be 'time' or 'freq', got {!r}".format(axis_type))
    wrapphase = True
    # Values have shape (timestamps, frequencies, antennas, polarizations, directions).
    axes = st.getAxesNames()
    st_type = st.getType()
    x_axis = vals[1][axis_type]
    values = vals[0]
    plabels = []
    isphase = False
    if weights is None:
        weights = np.ones(values.shape)

    if axis_type == 'time':
        if ('rotationmeasure' in st.name) or ('faraday' in st.name) or ('tec' in st.name and 'freq' not in axes and 'dir' not in axes):
            Y_AXIS = values[:, antenna] - values[:, refantenna]
            Y_AXIS_WEIGHT = weights[:, antenna]
        elif ('pol' in axes) and ('dir' in axes):
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[:, freqslot, antenna, :, direction] - values[:, freqslot, refantenna, :, direction]
                y_axis_weight = weights[:, freqslot, antenna, :, direction]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            elif (st_type == 'clock') or (st_type == 'rotationmeasure') or (st_type == 'tec' and 'freq' not in axes):
                y_axis = values[:, antenna, direction] - values[:, refantenna, direction]
                y_axis_weight = weights[:, antenna, direction]
            else:
                y_axis = values[:, freqslot, antenna, :, direction]
                y_axis_weight = weights[:, freqslot, antenna, :, direction]
            Y_AXIS, Y_AXIS_WEIGHT, plabels = _split_pol_axis(y_axis, y_axis_weight, vals[1]['pol'])
        elif 'pol' in axes:
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[:, freqslot, antenna, :] - values[:, freqslot, refantenna, :]
                y_axis_weight = weights[:, freqslot, antenna, :]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            elif st_type in ('clock', 'rotationmeasure', 'tec'):
                y_axis = values[:, antenna] - values[:, refantenna]
                y_axis_weight = weights[:, antenna]
            else:
                y_axis = values[:, freqslot, antenna, :]
                y_axis_weight = weights[:, freqslot, antenna, :]
            Y_AXIS, Y_AXIS_WEIGHT, plabels = _split_pol_axis(y_axis, y_axis_weight, vals[1]['pol'])
        elif 'dir' in axes:
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[:, freqslot, antenna, direction] - values[:, freqslot, refantenna, direction]
                y_axis_weight = weights[:, freqslot, antenna, direction]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            elif (st_type == 'clock') or (st_type == 'rotationmeasure') or (st_type == 'tec' and 'freq' not in axes):
                y_axis = values[:, antenna, direction] - values[:, refantenna, direction]
                y_axis_weight = weights[:, antenna, direction]
            else:
                y_axis = values[:, freqslot, antenna, direction]
                y_axis_weight = weights[:, freqslot, antenna, direction]
            Y_AXIS = y_axis
            Y_AXIS_WEIGHT = y_axis_weight
        else:
            # Neither a polarization nor a direction axis.
            if (st_type == 'clock') or (st_type == 'rotationmeasure') or ((st_type == 'tec' or st_type == 'phase') and 'freq' not in axes):
                y_axis = values[:, antenna] - values[:, refantenna]
                y_axis_weight = weights[:, antenna]
            elif (st_type == 'tec' or st_type == 'phase') and 'freq' in axes:
                # Use the selected frequency slot (previously always slot 0).
                y_axis = values[:, freqslot, antenna] - values[:, freqslot, refantenna]
                y_axis_weight = weights[:, freqslot, antenna]
            else:
                # Assume normal time, freq, antenna ordering.
                y_axis = values[:, freqslot, antenna]
                y_axis_weight = weights[:, freqslot, antenna]
            if st_type == 'phase':
                # Treat phases like in every other branch: flag and wrap them.
                isphase = True
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            Y_AXIS = y_axis
            Y_AXIS_WEIGHT = y_axis_weight
    else:
        # axis_type == 'freq'
        if ('rotationmeasure' in st.name) or ('clock' in st.name) or ('faraday' in st.name) or ('tec' in st.name):
            logging.warning('%s does not support frequency axis! Switch to time instead.', st.name)
        if ('pol' in axes) and ('dir' in axes):
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[timeslot, :, antenna, :, direction] - values[timeslot, :, refantenna, :, direction]
                y_axis_weight = weights[timeslot, :, antenna, :, direction]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            else:
                y_axis = values[timeslot, :, antenna, :, direction]
                y_axis_weight = weights[timeslot, :, antenna, :, direction]
            Y_AXIS, Y_AXIS_WEIGHT, plabels = _split_pol_axis(y_axis, y_axis_weight, vals[1]['pol'])
        elif 'pol' in axes:
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[timeslot, :, antenna, :] - values[timeslot, :, refantenna, :]
                y_axis_weight = weights[timeslot, :, antenna, :]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            else:
                y_axis = values[timeslot, :, antenna, :]
                y_axis_weight = weights[timeslot, :, antenna, :]
            Y_AXIS, Y_AXIS_WEIGHT, plabels = _split_pol_axis(y_axis, y_axis_weight, vals[1]['pol'])
        elif 'dir' in axes:
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna.
                y_axis = values[timeslot, :, antenna, direction] - values[timeslot, :, refantenna, direction]
                y_axis_weight = weights[timeslot, :, antenna, direction]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            else:
                y_axis = values[timeslot, :, antenna, direction]
                y_axis_weight = weights[timeslot, :, antenna, direction]
            Y_AXIS = y_axis
            Y_AXIS_WEIGHT = y_axis_weight
        else:
            # Neither a polarization nor a direction axis.
            if st_type == 'phase':
                isphase = True
                # Plot phase-like quantities w.r.t. to a reference antenna, as in the other branches.
                y_axis = values[timeslot, :, antenna] - values[timeslot, :, refantenna]
                if wrapphase:
                    y_axis = wrap_phase(y_axis)
            else:
                y_axis = values[timeslot, :, antenna]
            Y_AXIS = y_axis
            Y_AXIS_WEIGHT = weights[timeslot, :, antenna]
    if len(plabels) == 0:
        plabels = ['', '', '', '']
    return x_axis, Y_AXIS, Y_AXIS_WEIGHT, plabels, isphase


def load_axes_2d(vals, weights, st, antenna, refantenna, pol=0, direction=0):
    """ Load a 2D (time, frequency) slice from the H5Parm.

    Phase-like soltabs (`phase`, `rotation`) are differenced against the reference antenna and
    wrapped; for `polalign` soltabs the XX-YY phase difference is shown instead of a single
    polarization. Non-phase soltabs are returned as stored, except that soltabs with only
    (time, freq, ant) axes that are not amplitudes are differenced against the reference antenna.

    Args:
        vals (tuple): (values, axes) pair as obtained from Soltab.getValues(); `axes` maps axis
            names to their coordinate values.
        weights (ndarray): weights corresponding to the values in `vals`.
        st (Soltab): soltab object; its axes names, type and name select the slicing.
        antenna (int): index of the antenna to select.
        refantenna (int): index of the reference antenna.
        pol (int): index of the polarization to load.
        direction (int): index of the direction to load.
    Returns:
        x_axis (ndarray): time axis.
        y_axis (ndarray): frequency axis.
        plotvals (ndarray): 2D array of soltab values, indexed as [time, freq].
        plotvals_weight (ndarray): weights corresponding to `plotvals`.
        isphase (bool): boolean indicating whether or not the quantity is a phase.
    Raises:
        ValueError: for a `polalign` soltab that has no polarization axis.
    """
    wrapphase = True
    # Values have shape (timestamps, frequencies, antennas, polarizations, directions).
    axes = st.getAxesNames()
    st_type = st.getType()
    x_axis = vals[1]['time']
    y_axis = vals[1]['freq']
    values = np.asarray(vals[0])

    try:
        pols = list(vals[1]['pol'])
    except KeyError:
        pols = None
        logging.debug('No polarization axis present.')
    isphase = st_type in ('phase', 'rotation')
    if isphase:
        if 'polalign' in st.name:
            # Special case. We must plot XX-YY and XX is 0 so not informative anyway.
            if pols is None:
                raise ValueError('Soltab {} has no polarization axis to compute XX-YY from.'.format(st.name))
            xx = pols.index('XX')
            yy = pols.index('YY')
            if 'dir' in axes:
                plotvals = (values[:, :, antenna, xx, direction] - values[:, :, antenna, yy, direction]) - (values[:, :, refantenna, xx, direction] - values[:, :, refantenna, yy, direction])
                plotvals_weight = weights[:, :, antenna, pol, direction]
            else:
                plotvals = (values[:, :, antenna, xx] - values[:, :, antenna, yy]) - (values[:, :, refantenna, xx] - values[:, :, refantenna, yy])
                plotvals_weight = weights[:, :, antenna, pol]
        else:
            if 'pol' in axes:
                if 'dir' in axes:
                    plotvals = values[:, :, antenna, pol, direction] - values[:, :, refantenna, pol, direction]
                    plotvals_weight = weights[:, :, antenna, pol, direction]
                else:
                    plotvals = values[:, :, antenna, pol] - values[:, :, refantenna, pol]
                    plotvals_weight = weights[:, :, antenna, pol]
            elif 'dir' in axes:
                plotvals = values[:, :, antenna, direction] - values[:, :, refantenna, direction]
                plotvals_weight = weights[:, :, antenna, direction]
            else:
                plotvals = values[:, :, antenna] - values[:, :, refantenna]
                plotvals_weight = weights[:, :, antenna]
        if wrapphase:
            plotvals = wrap_phase(plotvals)
    else:
        if 'pol' in axes:
            if 'dir' in axes:
                plotvals = values[:, :, antenna, pol, direction]
                plotvals_weight = weights[:, :, antenna, pol, direction]
            else:
                plotvals = values[:, :, antenna, pol]
                plotvals_weight = weights[:, :, antenna, pol]
        elif 'dir' in axes:
            plotvals = values[:, :, antenna, direction]
            plotvals_weight = weights[:, :, antenna, direction]
        else:
            if st_type == 'amplitude':
                # Amplitudes are never referenced, matching every other non-phase branch.
                plotvals = values[:, :, antenna]
            else:
                plotvals = values[:, :, antenna] - values[:, :, refantenna]
            plotvals_weight = weights[:, :, antenna]
    return x_axis, y_axis, plotvals, plotvals_weight, isphase


class GraphWindow(QDialog):
    """ A window displaying the plotted quantity. Allows the user to cycle through time or frequency.
    """
    def __init__(self, values, weights, frametitle, antindex, refantindex, axis, st, timeslot=0, freqslot=0, direction=0, times=None, freqs=None, parent=None, mode='values', do_timediff=False, do_freqdiff=False, do_poldiff=False):
        """ Initialize a new GraphWindow instance.

        Args:
            values (tuple): soltab (values, axes) pair, passed to `load_axes` whenever the view changes.
            weights (ndarray or None): weights matching the soltab values.
            frametitle (str): title the frame will have.
            antindex (int): the index of the selected antenna.
            refantindex (int): the index of the reference antenna.
            axis (str): the type of axis being plotted (time or freq).
            st (Soltab): the soltab being plotted.
            timeslot (int): index along the time axis to start with.
            freqslot (int): index along the frequency axis to start with.
            direction (str): name of the direction to plot.
            times (ndarray or None): time axis values; stored relative to the first timestamp.
            freqs (ndarray or None): frequency axis values in Hz.
            parent (QDialog): parent window instance; its `stations` and `directions` are used
                to iterate over antennas and directions.
            mode (str): `values` or `weights`, selecting what is plotted.
            do_timediff (bool): plot differences between consecutive time slots.
            do_freqdiff (bool): plot differences between consecutive frequency slots.
            do_poldiff (bool): plot the difference between the first and last polarization.
        Returns:
            None
        """
        super(GraphWindow, self).__init__()
        self.setWindowFlags(QtCore.Qt.WindowSystemMenuHint | QtCore.Qt.WindowMinMaxButtonsHint | QtCore.Qt.WindowCloseButtonHint)
        # Set up for logging output.
        self.LOGGER = logging.getLogger('GraphWindow')
        self.LOGGER.setLevel(logging.DEBUG)

        self.frametitle = frametitle
        self.axis = axis
        # Start at the requested slots; the selection label and buttons below are set up to match.
        self.timeslot = timeslot
        self.freqslot = freqslot
        self.direction = direction
        self.values = values
        self.weights = weights
        self.antindex = antindex
        self.refantindex = refantindex
        self.st = st
        self.parent = parent
        self.mode = mode
        self.do_timediff = do_timediff
        self.do_freqdiff = do_freqdiff
        self.do_poldiff = do_poldiff
        # None when there is no frequency axis (plotting against time).
        self.frequencies = freqs
        # Times relative to the first timestamp, or None without a time axis. Computed on a copy
        # so the caller's array is not shifted in place.
        self.times = None if times is None else np.asarray(times) - times[0]

        # The 'GraphWindow' logger is shared by every window: attach the file handler only once,
        # otherwise each newly opened window would duplicate every log line.
        file_handlers = [h for h in self.LOGGER.handlers if isinstance(h, logging.FileHandler)]
        if file_handlers:
            self.LOGFILEH = file_handlers[0]
        else:
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            self.LOGFILEH = logging.FileHandler('h5plot.log')
            self.LOGFILEH.setLevel(logging.DEBUG)
            self.LOGFILEH.setFormatter(formatter)
            self.LOGGER.addHandler(self.LOGFILEH)

        self.setWindowTitle(frametitle)

        self.button_next = QPushButton('Forward')
        self.button_next.clicked.connect(self._forward_button_event)
        self.button_prev = QPushButton('Back')
        self.button_prev.clicked.connect(self._backward_button_event)
        if 'time' in axis.lower():
            # Plotting against time: the buttons step through frequency slots.
            try:
                label = 'Frequency: {:.3f} MHz'.format(self.frequencies[self.freqslot] / 1e6)
                nslots = len(self.frequencies)
            except TypeError:
                # No frequency axis.
                label, nslots = '', 0
            slot = self.freqslot
        elif 'freq' in axis.lower():
            # Plotting against frequency: the buttons step through time slots.
            try:
                label = 'Time: ' + self.format_time(self.times[self.timeslot])
                nslots = len(self.times)
            except TypeError:
                # No time axis.
                label, nslots = '', 0
            slot = self.timeslot
        else:
            label, nslots, slot = '', 0, 0
        self.select_label = QLabel(label)
        self.select_label.setAlignment(QtCore.Qt.AlignCenter)
        # Each button is enabled exactly when there is a slot to move to in its direction.
        self.button_prev.setEnabled(0 < slot < nslots)
        self.button_next.setEnabled(slot < nslots - 1)

        self.btn_antiter_next = QPushButton('Next antenna')
        self.btn_antiter_next.clicked.connect(self._antiter_next_button_event)
        self.btn_antiter_prev = QPushButton('Previous antenna')
        self.btn_antiter_prev.clicked.connect(self._antiter_prev_button_event)

        self.btn_diriter_next = QPushButton('Next direction')
        self.btn_diriter_next.clicked.connect(self._diriter_next_button_event)
        self.btn_diriter_prev = QPushButton('Previous direction')
        self.btn_diriter_prev.clicked.connect(self._diriter_prev_button_event)

        antiter_widget = QWidget()
        antiter_layout = QHBoxLayout(antiter_widget)
        antiter_layout.addWidget(self.btn_antiter_prev)
        antiter_layout.addWidget(self.btn_antiter_next)
        antiter_layout.addWidget(self.btn_diriter_prev)
        antiter_layout.addWidget(self.btn_diriter_next)

        self.buttons = QGridLayout()
        self.buttons.addWidget(self.button_prev, 0, 0)
        self.buttons.addWidget(self.select_label, 0, 1)
        self.buttons.addWidget(self.button_next, 0, 2)

        self.scrolls = QGridLayout()
        self.scrollbar = QScrollBar()
        self.scrollbar.setOrientation(QtCore.Qt.Horizontal)
        if 'time' in axis.lower() and self.frequencies is not None:
            self.scrollbar.setRange(0, len(self.frequencies) - 1)
            self.scrollbar.setValue(self.freqslot)
        elif 'freq' in axis.lower() and self.times is not None:
            self.scrollbar.setRange(0, len(self.times) - 1)
            self.scrollbar.setValue(self.timeslot)
        else:
            self.scrollbar.setDisabled(True)
        # Connected after the initial value is set so that does not trigger a redraw before the first plot.
        self.scrollbar.valueChanged.connect(self._scrollbar_event)
        self.scrolls.addWidget(self.scrollbar, 0, 0)

        self.fig = Figure()
        self.canvas = FigureCanvas(self.fig)
        self.toolbar = NavigationToolbar(self.canvas, self)

        self.layout = QVBoxLayout()
        self.layout.addWidget(self.toolbar)
        self.layout.addWidget(self.canvas, stretch=500)
        self.layout.addItem(self.buttons)
        self.layout.addItem(self.scrolls)
        self.layout.addWidget(antiter_widget)
        self.setLayout(self.layout)

    def format_time(self, seconds):
        """ Format a duration for display in the plotting windows.

        Durations below a minute are shown in seconds, below an hour in minutes and otherwise
        in hours, always with three decimals. A value that fits none of these ranges (e.g. NaN)
        is printed with three decimals and no unit.

        Args:
            seconds (float): the duration in seconds.
        Returns:
            formatted time (str): the formatted time string.
        """
        if seconds < 60:
            amount, unit = seconds, ' sec'
        elif seconds < 3600:
            amount, unit = seconds / 60, ' min'
        elif seconds >= 3600:
            amount, unit = seconds / 3600, ' hr'
        else:
            # Only reachable for values that compare false to everything, such as NaN.
            amount, unit = seconds, ''
        return '{:.3f}'.format(amount) + unit

    def _antiter_next_button_event(self):
        """ Show the next antenna of the parent's station list, wrapping around to the first.

        The current time slot, frequency slot and direction are kept.
        """
        self.antindex = (self.antindex + 1) % len(self.parent.stations)
        diridx = self.parent.directions.index(self.direction)
        # Pass both slots so a frequency plot stays on the time slot shown in the label.
        x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, timeslot=self.timeslot, freqslot=self.freqslot, direction=diridx, weights=self.weights)
        self.frametitle = self.parent.stations[self.antindex]
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)
        self.setWindowTitle(self.frametitle)

    def _antiter_prev_button_event(self):
        """ Show the previous antenna of the parent's station list, wrapping around to the last.

        The current time slot, frequency slot and direction are kept.
        """
        # Modulo keeps antenna 0 reachable (the old `> 0` test skipped it) and wraps 0 -> last.
        self.antindex = (self.antindex - 1) % len(self.parent.stations)
        diridx = self.parent.directions.index(self.direction)
        # Pass both slots so a frequency plot stays on the time slot shown in the label.
        x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, timeslot=self.timeslot, freqslot=self.freqslot, direction=diridx, weights=self.weights)
        self.frametitle = self.parent.stations[self.antindex]
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)
        self.setWindowTitle(self.frametitle)

    def _diriter_next_button_event(self):
        """ Show the next direction of the parent's direction list, wrapping around to the first.

        The current antenna, time slot and frequency slot are kept.
        """
        diridx = (self.parent.directions.index(self.direction) + 1) % len(self.parent.directions)
        self.direction = self.parent.directions[diridx]
        # Pass both slots so a frequency plot stays on the time slot shown in the label.
        x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, timeslot=self.timeslot, freqslot=self.freqslot, direction=diridx, weights=self.weights)
        self.frametitle = self.parent.stations[self.antindex]
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)
        self.setWindowTitle(self.frametitle)

    def _diriter_prev_button_event(self):
        """ Show the previous direction of the parent's direction list, wrapping around to the last.

        The current antenna, time slot and frequency slot are kept.
        """
        # Modulo keeps direction 0 reachable (the old `> 0` test skipped it) and wraps 0 -> last.
        diridx = (self.parent.directions.index(self.direction) - 1) % len(self.parent.directions)
        self.direction = self.parent.directions[diridx]
        # Pass both slots so a frequency plot stays on the time slot shown in the label.
        x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, timeslot=self.timeslot, freqslot=self.freqslot, direction=diridx, weights=self.weights)
        self.frametitle = self.parent.stations[self.antindex]
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)
        self.setWindowTitle(self.frametitle)

    def _forward_button_event(self):
        """ An event triggered by pressing the "Forward" button of a GraphWindow.

        When pressed, the abscissa is advanced one position, showing the next time or frequency
        slot. Nothing happens when the last slot is already shown.
        """
        if 'time' in self.xlabel.lower():
            if self.freqslot >= len(self.frequencies) - 1:
                return
            self.freqslot += 1
            self.select_label.setText('Frequency: {:.3f} MHz'.format(self.frequencies[self.freqslot] / 1e6))
            diridx = self.parent.directions.index(self.direction)
            x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, freqslot=self.freqslot, direction=diridx, weights=self.weights)
            # After a step forward there is always a previous slot.
            self.button_prev.setEnabled(True)
            self.button_next.setEnabled(self.freqslot < len(self.frequencies) - 1)
            self.scrollbar.setValue(self.freqslot)
        elif 'freq' in self.xlabel.lower():
            if self.timeslot >= len(self.times) - 1:
                return
            self.timeslot += 1
            self.select_label.setText('Time: ' + self.format_time(self.times[self.timeslot]))
            diridx = self.parent.directions.index(self.direction)
            x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, timeslot=self.timeslot, direction=diridx, weights=self.weights)
            # After a step forward there is always a previous slot.
            self.button_prev.setEnabled(True)
            self.button_next.setEnabled(self.timeslot < len(self.times) - 1)
            self.scrollbar.setValue(self.timeslot)
        else:
            # Unknown abscissa: nothing to step through.
            return
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)

    def _backward_button_event(self):
        """ An event triggered by pressing the "Back" button of a GraphWindow.

        When pressed, the abscissa is set back one position, showing the previous time or
        frequency slot. Nothing happens when the first slot is already shown (the scrollbar
        can leave the button enabled there).
        """
        if 'time' in self.xlabel.lower():
            if self.freqslot <= 0:
                return
            self.freqslot -= 1
            self.select_label.setText('Frequency: {:.3f} MHz'.format(self.frequencies[self.freqslot] / 1e6))
            diridx = self.parent.directions.index(self.direction)
            x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, freqslot=self.freqslot, direction=diridx, weights=self.weights)
            self.button_prev.setEnabled(self.freqslot > 0)
            self.button_next.setEnabled(self.freqslot < len(self.frequencies) - 1)
            self.scrollbar.setValue(self.freqslot)
        elif 'freq' in self.xlabel.lower():
            if self.timeslot <= 0:
                return
            self.timeslot -= 1
            self.select_label.setText('Time: ' + self.format_time(self.times[self.timeslot]))
            diridx = self.parent.directions.index(self.direction)
            x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, timeslot=self.timeslot, direction=diridx, weights=self.weights)
            self.button_prev.setEnabled(self.timeslot > 0)
            self.button_next.setEnabled(self.timeslot < len(self.times) - 1)
            self.scrollbar.setValue(self.timeslot)
        else:
            # Unknown abscissa: nothing to step through.
            return
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)

    def _scrollbar_event(self):
        """ An event triggered by moving the scrollbar of a GraphWindow.

        Jumps to the selected time or frequency slot and redraws. The Back/Forward buttons are
        updated so each is enabled exactly when there is a slot to move to in its direction.
        """
        if 'time' in self.xlabel.lower() and len(self.frequencies) > 1:
            self.freqslot = self.scrollbar.value()
            self.select_label.setText('Frequency: {:.3f} MHz'.format(self.frequencies[self.freqslot] / 1e6))
            diridx = self.parent.directions.index(self.direction)
            x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, freqslot=self.freqslot, direction=diridx, weights=self.weights)
            slot, nslots = self.freqslot, len(self.frequencies)
        elif 'freq' in self.xlabel.lower():
            self.timeslot = self.scrollbar.value()
            self.select_label.setText('Time: ' + self.format_time(self.times[self.timeslot]))
            diridx = self.parent.directions.index(self.direction)
            x, y, yw, l, p = load_axes(self.values, self.st, self.axis, self.antindex, self.refantindex, timeslot=self.timeslot, direction=diridx, weights=self.weights)
            slot, nslots = self.timeslot, len(self.times)
        else:
            # Nothing to step through (e.g. a single frequency slot).
            return
        self.button_prev.setEnabled(slot > 0)
        self.button_next.setEnabled(slot < nslots - 1)
        self.plot(x, y, yw, self.frametitle, ax_labels=[self.xlabel, self.ylabel], plot_labels=l, isphase=p)

    def plot(self, xaxis, yaxis, yaxis_weight, frametitle='', limits=[None, None], ax_labels=['', ''], plot_labels=[], multidim=False, isphase=False):
        self.xlabel = ax_labels[0]
        self.ylabel = ax_labels[1]
        self.fig.clf()
        self.ax = self.fig.add_subplot(111)
        self.ax.clear()
        if 'time' in ax_labels[0]:
            # Start counting from t=0
            xaxis = xaxis - xaxis[0]
        self.ax.set_title(frametitle + ' - {:s}'.format(self.direction))
        if 'time' in ax_labels[0] and self.do_timediff:
            self.ax.set_title(frametitle + ' - {:s} time diff'.format(self.direction))
        if 'freq' in ax_labels[0] and self.do_freqdiff:
            self.ax.set_title(frametitle + ' - {:s} frequency diff'.format(self.direction))
        if self.ax.get_legend_handles_labels()[1]:
            self.ax.legend()
        if type(xaxis) is list:
            xaxis = np.asarray(xaxis)
        if type(yaxis) is list:
            yaxis = np.asarray(yaxis)
        if type(yaxis_weight) is list:
            yaxis_weight = np.asarray(yaxis_weight)
            # Set weights to 0 for NaN solutions.
            yaxis_weight[np.isnan(yaxis)] = 0
        if len(yaxis.shape) > 1 and len(plot_labels) != 0:
            if self.do_poldiff:
                # Need to plot polarisation difference.
                v = yaxis[0, :] - yaxis[-1, :]
                # Handle flagged data. Weights are polarization independent.
                vw = yaxis_weight[0, :]
                v_m = np.ma.masked_where(vw == 0, v)
                if 'time' in ax_labels[0] and self.do_timediff:
                    v_m = np.ma.diff(v_m, axis=0)
                    xaxis = xaxis[:-1]
                    vw = vw[:-1]
                    if self.st.getType() == 'phase':
                        v_m = wrap_phase(v_m)
                if 'freq' in ax_labels[0] and self.do_freqdiff:
                    v_m = np.diff(v_m, axis=0)
                    xaxis = xaxis[:-1]
                    vw = vw[:-1]
                    if self.st.getType() == 'phase':
                        v_m = wrap_phase(v_m)
                self.ax.plot(xaxis, v_m, '--', alpha=0.25, color='C0')
                self.ax.plot(xaxis, v_m, 'h', label='XX - YY', color='C0')
                if np.ma.is_masked(v_m):
                    for x in xaxis[vw==0]:
                        self.ax.axvline(x, color='r')
                    self.ax.plot(xaxis[np.ma.getmaskarray(v_m)], np.ma.getdata(v_m)[np.ma.getmaskarray(v_m)], 'h', color='r')
            else:
                for i in range(yaxis.shape[0]):
                    if self.mode == 'values':
                        v = yaxis[i, :]
                        # Handle flagged data. Weights are polarization independent.
                        vw = yaxis_weight[i, :]
                        v_m = np.ma.masked_where(vw == 0, v)
                        if 'time' in ax_labels[0] and self.do_timediff:
                            v_m = np.ma.diff(v_m, axis=0)
                            if i == 0:
                                xaxis = xaxis[:-1]
                            vw = vw[:-1]
                            if self.st.getType() == 'phase':
                                v_m = wrap_phase(v_m)
                        if 'freq' in ax_labels[0] and self.do_freqdiff:
                            v_m = np.diff(v_m, axis=0)
                            if i == 0:
                                xaxis = xaxis[:-1]
                            vw = vw[:-1]
                            if self.st.getType() == 'phase':
                                v_m = wrap_phase(v_m)
                        self.ax.plot(xaxis, v_m, '--', alpha=0.25, color='C' + str(i))
                        self.ax.plot(xaxis, v_m, 'h', label=plot_labels[i], color='C' + str(i))
                        if np.ma.is_masked(v_m):
                            for x in xaxis[vw==0]:
                                self.ax.axvline(x, color='r')
                            self.ax.plot(xaxis[np.ma.getmaskarray(v_m)], np.ma.getdata(v_m)[np.ma.getmaskarray(v_m)], 'h', color='r')
                    elif self.mode == 'weights':
                        vw = yaxis_weight[i, :]
                        self.ax.plot(xaxis, vw, '--', alpha=0.25, color='C' + str(i))
                        self.ax.plot(xaxis, vw, 'h', label=plot_labels[i], color='C' + str(i))
            self.ax.legend()
        elif len(yaxis.shape) > 1 and len(plot_labels) == 0:
            if self.do_poldiff:
                # Need to plot polarisation difference.
                v = yaxis[0, :] - yaxis[-1, :]
                # Handle flagged data. Weights are polarization independent.
                vw = yaxis_weight[0, :]
                v_m = np.ma.masked_where(vw == 0, v)
                if 'time' in ax_labels[0] and self.do_timediff:
                    v_m = np.ma.diff(v_m, axis=0)
                    if i == 0:
                        xaxis = xaxis[:-1]
                    vw = vw[:-1]
                    if self.st.getType() == 'phase':
                        v_m = wrap_phase(v_m)
                if 'freq' in ax_labels[0] and self.do_freqdiff:
                    v_m = np.diff(v_m, axis=0)
                    if i == 0:
                        xaxis = xaxis[:-1]
                    vw = vw[:-1]
                    if self.st.getType() == 'phase':
                        v_m = wrap_phase(v_m)
                self.ax.plot(xaxis, v_m, '--', alpha=0.25, color='C0')
                self.ax.plot(xaxis, v_m, 'h', label='XX - YY', color='C0')
                if np.ma.is_masked(v_m):
                    for x in xaxis[vw==0]:
                        self.ax.axvline(x, color='r')
                    self.ax.plot(xaxis[np.ma.getmaskarray(v_m)], np.ma.getdata(v_m)[np.ma.getmaskarray(v_m)], 'h', color='r')
            else:
                for i in range(yaxis.shape[0]):
                    if self.mode == 'values':
                        v = yaxis[i, :]
                        vw = yaxis_weight[i, :]
                        v_m = np.ma.masked_where(yaxis_weight == 0, v)
                        if 'time' in ax_labels[0] and self.do_timediff:
                            v_m = np.diff(v_m, axis=0)
                            if i == 0:
                                xaxis = xaxis[:-1]
                            vw = vw[:-1]
                            if self.st.getType() == 'phase':
                                v_m = wrap_phase(v_m)
                        if 'freq' in ax_labels[0] and self.do_freqdiff:
                            v_m = np.diff(v_m, axis=0)
                            if i == 0:
                                xaxis = xaxis[:-1]
                            vw = vw[:-1]
                            if self.st.getType() == 'phase':
                                v_m = wrap_phase(v_m)
                        self.ax.plot(xaxis, v_m, '--', alpha=0.25, color='C' + str(i))
                        self.ax.plot(xaxis, v_m, 'h', color='C' + str(i))
                    elif self.mode == 'weights':
                        vw = yaxis_weight[i, :]
                        self.ax.plot(xaxis, vw, '--', alpha=0.25, color='C' + str(i))
                        self.ax.plot(xaxis, vw, 'h', label=plot_labels[i], color='C' + str(i))
        else:
            if self.mode == 'values':
                v = yaxis
                vw = yaxis_weight
                v_m = np.ma.masked_where(yaxis_weight == 0, v)
                if 'time' in ax_labels[0] and self.do_timediff:
                    v_m = np.diff(v_m, axis=0)
                    xaxis = xaxis[:-1]
                    vw = vw[:-1]
                    if self.st.getType() == 'phase':
                        v_m = wrap_phase(v_m)
                if 'freq' in ax_labels[0] and self.do_freqdiff:
                    v_m = np.diff(v_m, axis=0)
                    xaxis = xaxis[:-1]
                    vw = vw[:-1]
                    if self.st.getType() == 'phase':
                        v_m = wrap_phase(v_m)
                self.ax.plot(xaxis, v_m, '--', alpha=0.25, color='C0')
                self.ax.plot(xaxis, v_m, 'h', color='C0')
                if np.ma.is_masked(v_m):
                    for x in xaxis[vw==0]:
                        self.ax.axvline(x, color='r')
                    self.ax.plot(xaxis[np.ma.getmaskarray(v_m)], np.ma.getdata(v_m)[np.ma.getmaskarray(v_m)], 'h', color='r')
            elif self.mode == 'weights':
                    vw = yaxis_weight
                    self.ax.plot(xaxis, vw, '--', alpha=0.25, color='C0')
                    self.ax.plot(xaxis, vw, 'h', label=plot_labels, color='C0')
        if isphase:
            self.ax.set_ylim(-np.pi, np.pi)
        self.ax.set(xlabel=ax_labels[0], ylabel=ax_labels[1], xlim=limits[0], ylim=limits[1])
        self.canvas.draw()


class GraphWindow2D(QDialog):
    """ A window displaying the plotted 2D quantity. Allows the user to cycle through antenna.

    Either a single time-frequency image of one antenna is shown (`plot`), or a grid with one
    image per station (`plot_all`). Buttons step through correlations, antennas and directions.
    """
    def __init__(self, values, weights, frametitle, antindex, refantindex, axis, st, polslot=0, direction=0, times=None, freqs=None, pols=None, parent=None, mode='values', do_timediff=False, do_freqdiff=False, do_poldiff=False):
        """ Initialize a new GraphWindow instance.

        Args:
            values (ndarray): array of values to plot.
            weights (ndarray): array of weights corresponding to the values.
            frametitle (str): title the frame will have.
            antindex (int): the index of the selected antenna.
            refantindex (int): the index of the reference antenna.
            axis: the axis being plotted against; stored as `self.axis` (not used within this class).
            st (Soltab): the soltab the values belong to.
            polslot (int): index along the polarization axis to start with.
            direction (str): the direction to plot; must be an entry of `parent.directions`.
            times (ndarray or None): time axis; stored shifted so that it starts at 0.
            freqs (ndarray or None): frequency axis.
            pols (sequence or None): correlation labels, or None when there is no pol axis.
            parent (QWidget): main window providing `stations`, `directions`, `direction` and `stcache`.
            mode (str): `values` or `weights`.
            do_timediff (bool): plot differences between consecutive time slots.
            do_freqdiff (bool): plot differences between consecutive frequency slots.
            do_poldiff (bool): plot the first minus the last correlation.
        Returns:
            None
        """
        import os

        super(GraphWindow2D, self).__init__()
        self.setWindowFlags(QtCore.Qt.WindowSystemMenuHint | QtCore.Qt.WindowMinMaxButtonsHint | QtCore.Qt.WindowCloseButtonHint)
        # Set up for logging output.
        self.LOGGER = logging.getLogger('GraphWindow')
        self.LOGGER.setLevel(logging.DEBUG)

        self.frametitle = frametitle
        self.axis = axis
        self.polslot = polslot
        self.direction = direction
        self.values = values
        self.weights = weights
        self.antindex = antindex
        self.refantindex = refantindex
        self.st = st
        self.parent = parent
        self.mode = mode
        self.do_timediff = do_timediff
        self.do_freqdiff = do_freqdiff
        self.do_poldiff = do_poldiff
        # plot()/plot_all() update this; initialise it so the button handlers work before any plot.
        self.gridplot = False
        self.polarizations = pols
        self.frequencies = freqs
        if times is not None:
            # Start counting from t=0 without modifying the caller's array in place.
            times = np.asarray(times)
            if times.size:
                times = times - times[0]
        self.times = times

        # The logger is shared by every window; reuse its h5plot.log handler instead of adding
        # (and leaving open) a new FileHandler each time a window is created.
        logpath = os.path.abspath('h5plot.log')
        existing = [h for h in self.LOGGER.handlers if isinstance(h, logging.FileHandler) and h.baseFilename == logpath]
        if existing:
            self.LOGFILEH = existing[0]
        else:
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
            self.LOGFILEH = logging.FileHandler('h5plot.log')
            self.LOGFILEH.setLevel(logging.DEBUG)
            self.LOGFILEH.setFormatter(formatter)
            self.LOGGER.addHandler(self.LOGFILEH)

        self.setWindowTitle(frametitle)

        self.button_next = QPushButton('Forward')
        self.button_next.clicked.connect(self._forward_button_event)
        self.button_prev = QPushButton('Back')
        self.button_prev.clicked.connect(self._backward_button_event)
        self.button_prev.setEnabled(False)
        try:
            if self.st.getType() not in ['rotation', 'tec']:
                # Raises TypeError when there is no pol axis (pols is None).
                self.select_label = QLabel('Corr.: {:s}'.format(self.polarizations[polslot]))
                if len(self.polarizations) == 1:
                    self.button_next.setEnabled(False)
                else:
                    self.button_next.setEnabled(True)
            else:
                raise TypeError
        except TypeError:
            # No pol axis.
            self.select_label = QLabel('')
            self.button_next.setEnabled(False)
            self.button_prev.setEnabled(False)
        self.select_label.setAlignment(QtCore.Qt.AlignCenter)

        self.btn_antiter_next = QPushButton('Next antenna')
        self.btn_antiter_next.clicked.connect(self._antiter_next_button_event)
        self.btn_antiter_prev = QPushButton('Previous antenna')
        self.btn_antiter_prev.clicked.connect(self._antiter_prev_button_event)

        self.btn_diriter_next = QPushButton('Next direction')
        self.btn_diriter_next.clicked.connect(self._diriter_next_button_event)
        self.btn_diriter_prev = QPushButton('Previous direction')
        self.btn_diriter_prev.clicked.connect(self._diriter_prev_button_event)

        antiter_widget = QWidget()
        antiter_layout = QHBoxLayout(antiter_widget)
        antiter_layout.addWidget(self.btn_antiter_prev)
        antiter_layout.addWidget(self.btn_antiter_next)
        antiter_layout.addWidget(self.btn_diriter_prev)
        antiter_layout.addWidget(self.btn_diriter_next)

        self.buttons = QGridLayout()
        self.buttons.addWidget(self.button_prev, 0, 0)
        self.buttons.addWidget(self.select_label, 0, 1)
        self.buttons.addWidget(self.button_next, 0, 2)

        self.fig = Figure(figsize=(16,9))
        self.canvas = FigureCanvas(self.fig)
        self.toolbar = NavigationToolbar(self.canvas, self)

        self.layout = QVBoxLayout()
        self.layout.addWidget(self.toolbar)
        self.layout.addWidget(self.canvas, stretch=500)
        self.layout.addItem(self.buttons)
        self.layout.addWidget(antiter_widget)
        self.setLayout(self.layout)

        if self.do_poldiff:
            # A fixed correlation difference is shown, so correlation stepping is disabled.
            self.button_prev.setEnabled(False)
            self.button_next.setEnabled(False)
            if self.polarizations is not None:
                self.select_label.setText('Corr.: {:s} - {:s}'.format(self.polarizations[0], self.polarizations[-1]))

    def format_time(self, seconds):
        """ Formats the time to be displayed in the plotting windows.

        A string is formatted, displaying the time in seconds or (fractional) minutes or hours.

        Args:
            seconds (int): the time in seconds.
        Returns:
            formatted time (str): formatted time string.
        """
        if seconds < 60:
            return '{:.3f} sec'.format(seconds)
        elif 60 <= seconds < 3600:
            return '{:.3f} min'.format(seconds / 60)
        elif seconds >= 3600:
            return '{:.3f} hr'.format(seconds / 3600)
        else:
            # Only reachable for values that fail every comparison, e.g. NaN.
            return '{:.3f}'.format(seconds)

    @staticmethod
    def _colormap(name):
        """ Return colormap `name` with bad (NaN/masked) values drawn in white.

        Prefers a private copy from the `matplotlib.colormaps` registry (matplotlib >= 3.5) and
        falls back to `matplotlib.cm.get_cmap`, which newer matplotlib versions no longer provide.
        """
        try:
            cmap = matplotlib.colormaps[name].copy()
        except AttributeError:
            cmap = matplotlib.cm.get_cmap(name)
        cmap.set_bad(color='w')
        return cmap

    def _poldiff_active(self):
        """ Return True when the first-minus-last correlation difference should be shown. """
        return bool(self.do_poldiff) and self.polarizations is not None and len(self.polarizations) > 1

    def _load_current(self, diridx):
        """ Load (x, y, z, weights, isphase) for the current antenna and correlation at direction index `diridx`. """
        if self._poldiff_active():
            # Need to plot polarisation difference.
            x, y, z1, zw1, p = load_axes_2d(self.values, self.weights, self.st, self.antindex, self.refantindex, 0, direction=diridx)
            _, _, z2, zw2, _ = load_axes_2d(self.values, self.weights, self.st, self.antindex, self.refantindex, -1, direction=diridx)
            # Combine flags of both polarisations.
            return x, y, z1 - z2, zw1 * zw2, p
        x, y, z, zw, p = load_axes_2d(self.values, self.weights, self.st, self.antindex, self.refantindex, self.polslot, direction=diridx)
        return x, y, z, zw, p

    def _redraw(self, diridx):
        """ Redraw the window: the station grid when in grid mode, otherwise the current antenna. """
        if self.gridplot:
            self.plot_all()
            return
        x, y, z, zw, p = self._load_current(diridx)
        if self.mode == 'values':
            self.plot(x, y, z, self.frametitle, ax_labels=('Time [s]', 'Freq. [MHz]'), isphase=p)
        elif self.mode == 'weights':
            self.plot(x, y, zw, self.frametitle, ax_labels=('Time [s]', 'Freq. [MHz]'))

    def _show_antenna(self):
        """ Redraw after `self.antindex` changed and retitle the window with the station name. """
        diridx = self.parent.directions.index(self.direction)
        self.frametitle = self.parent.stations[self.antindex]
        self._redraw(diridx)
        self.setWindowTitle(self.frametitle)

    def _show_direction(self, diridx):
        """ Switch to direction index `diridx` (also stored on the parent) and redraw. """
        self.parent.direction = diridx
        self.direction = self.parent.directions[diridx]
        self.frametitle = self.parent.stations[self.antindex]
        self._redraw(diridx)
        self.setWindowTitle(self.frametitle)

    def _antiter_next_button_event(self):
        """ Show the next antenna, wrapping from the last antenna to the first. """
        self.antindex = (self.antindex + 1) % len(self.parent.stations)
        self._show_antenna()

    def _antiter_prev_button_event(self):
        """ Show the previous antenna, wrapping from the first antenna to the last. """
        # Modulo keeps antenna 0 reachable and wraps from 0 to the last antenna.
        self.antindex = (self.antindex - 1) % len(self.parent.stations)
        self._show_antenna()

    def _diriter_next_button_event(self):
        """ Show the next direction, wrapping from the last direction to the first. """
        diridx = self.parent.directions.index(self.direction)
        self._show_direction((diridx + 1) % len(self.parent.directions))

    def _diriter_prev_button_event(self):
        """ Show the previous direction, wrapping from the first direction to the last. """
        diridx = self.parent.directions.index(self.direction)
        # Modulo keeps direction 0 reachable and stores a non-negative index on the parent.
        self._show_direction((diridx - 1) % len(self.parent.directions))

    def _forward_button_event(self):
        """ An event triggered by pressing the "Forward" button of a GraphWindow.

        When pressed, the next correlation (polarization slot) is selected and the plot is redrawn.
        """
        if (self.polslot < (len(self.polarizations) - 1)):
            self.polslot += 1
        if (self.polslot > 0) and (not self.button_prev.isEnabled()):
            self.button_prev.setEnabled(True)
        if self.polslot == (len(self.polarizations) - 1):
            self.button_next.setEnabled(False)
        self.select_label.setText('Corr.: {:s}'.format(self.polarizations[self.polslot]))
        self._redraw(self.parent.directions.index(self.direction))

    def _backward_button_event(self):
        """ An event triggered by pressing the "Back" button of a GraphWindow.

        When pressed, the previous correlation (polarization slot) is selected and the plot is redrawn.
        """
        if (self.polslot > 0):
            self.polslot -= 1
        if self.polslot == 0:
            self.button_prev.setEnabled(False)
        if (self.polslot < (len(self.polarizations) - 1)) and (not self.button_next.isEnabled()):
            self.button_next.setEnabled(True)
        self.select_label.setText('Corr.: {:s}'.format(self.polarizations[self.polslot]))
        self._redraw(self.parent.directions.index(self.direction))

    def _grid_slice(self, data, ant, pol):
        """ Return the (time, freq) plane of a cached array for antenna `ant` and correlation `pol`.

        Assumes the cached arrays are ordered (time, freq, ant[, pol][, dir]), as the indexing in
        this class does everywhere. `pol` is ignored when there is no pol axis, and the direction
        is taken from `self.parent.direction`.
        """
        axes = self.parent.stcache.axes
        index = [slice(None), slice(None), ant]
        if 'pol' in axes:
            index.append(pol)
        if 'dir' in axes:
            index.append(self.parent.direction)
        return data[tuple(index)]

    def plot_all(self):
        """ Plot a grid with one time-frequency image per station for the current direction. """
        from matplotlib.gridspec import GridSpec
        import matplotlib.patheffects as pe
        self.fig.clf()
        self.gridplot = True
        cache = self.parent.stcache
        Nplots = len(self.parent.stations)
        Ncols = int(np.ceil(np.sqrt(Nplots)))
        Nrows = int(np.ceil(float(Nplots)/Ncols))
        gs = GridSpec(Nrows, Ncols, figure=self.fig, left=0.05, right=0.95, bottom=0.05, top=0.95, hspace=0.04, wspace=0.025)
        self.fig.suptitle(self.st.name + ' - {:s}'.format(self.direction))
        try:
            self.fig.supxlabel('Time [s]')
            self.fig.supylabel('Frequency [MHz]', x=0.01)
        except AttributeError:
            self.LOGGER.warning('Setting labels with supxlabel/supylabel failed, not plotting axes labels. You may need to updated matplotlib to >3.4.')

        x = cache.values[1]['time']
        y = cache.values[1]['freq'] / 1e6
        # Images are (freq, time): time runs along the columns, frequency along the rows.
        extent = [0, x[-1]-x[0], y[0], y[-1]]
        isphase = self.st.getType() == 'phase'
        poldiff = self.do_poldiff and ('pol' in cache.axes)
        if self.mode == 'values' and isphase:
            cmap = self._colormap('jet')
        else:
            cmap = self._colormap('viridis')
        data = cache.values[0]

        im = None
        for Nplotted, station in enumerate(self.parent.stations):
            row, col = divmod(Nplotted, Ncols)
            ax = self.fig.add_subplot(gs[row, col])
            ax.set_title(station, y=1.0, pad=-15, color='white', fontweight='bold', path_effects=[pe.withStroke(linewidth=1, foreground="black")])
            if self.mode == 'values':
                if isphase:
                    if poldiff:
                        # Referenced phase of the first correlation minus that of the last one.
                        first = self._grid_slice(data, Nplotted, 0) - self._grid_slice(data, self.refantindex, 0)
                        last = self._grid_slice(data, Nplotted, -1) - self._grid_slice(data, self.refantindex, -1)
                        vals = wrap_phase(first - last).T
                    else:
                        vals = (self._grid_slice(data, Nplotted, self.polslot) - self._grid_slice(data, self.refantindex, self.polslot)).T
                    # vals is (freq, time): axis 1 is time, axis 0 is frequency.
                    if self.do_timediff:
                        vals = wrap_phase(np.diff(vals, axis=1))
                    if self.do_freqdiff:
                        vals = wrap_phase(np.diff(vals, axis=0))
                    im = ax.imshow(vals, aspect='auto', extent=extent, interpolation='none', cmap=cmap, origin='lower')
                    im.set_clim(-np.pi, np.pi)
                else:
                    if poldiff:
                        vals = (self._grid_slice(data, Nplotted, 0) - self._grid_slice(data, Nplotted, -1)).T
                    else:
                        vals = self._grid_slice(data, Nplotted, self.polslot).T
                    # vals is (freq, time): axis 1 is time, axis 0 is frequency.
                    if self.do_timediff:
                        vals = np.diff(vals, axis=1)
                    if self.do_freqdiff:
                        vals = np.diff(vals, axis=0)
                    im = ax.imshow(vals, aspect='auto', extent=extent, interpolation='none', cmap=cmap, origin='lower', vmin=0.5, vmax=2.0)
            elif self.mode == 'weights':
                vals = self._grid_slice(cache.weights, Nplotted, self.polslot).T
                im = ax.imshow(vals, aspect='auto', extent=extent, interpolation='none', cmap=cmap, origin='lower')
            # Only the left column keeps y tick labels and only the bottom panels keep x tick labels.
            if (Nplotted % Ncols):
                ax.yaxis.set_visible(False)
            if (Nplotted < Nplots - Ncols):
                ax.xaxis.set_visible(False)
        if im is not None:
            cbar_ax = self.fig.add_axes([0.955, 0.15, 0.015, 0.8])
            self.fig.colorbar(im, cax=cbar_ax)
        self.fig.tight_layout()
        self.setWindowState(QtCore.Qt.WindowMaximized)
        # Antenna stepping has no meaning while every station is shown.
        self.btn_antiter_prev.setEnabled(False)
        self.btn_antiter_next.setEnabled(False)
        self.canvas.draw()

    def plot(self, xaxis, yaxis, zaxis, frametitle='', limits=(None, None), ax_labels=('', ''), multidim=False, isphase=False):
        """ Plot a single time-frequency image.

        Irregular frequency or time axes are padded with NaN rows/columns so that the image is drawn
        on an approximately regular grid.

        Args:
            xaxis (ndarray): time axis; shifted to start at 0.
            yaxis (ndarray): frequency axis; multiplied by 1e-6 (Hz to MHz).
            zaxis (ndarray): 2D data indexed as (time, freq).
            frametitle (str): prefix of the plot title.
            limits (sequence): (xlim, ylim) passed to the axes.
            ax_labels (sequence): (xlabel, ylabel).
            multidim (bool): unused.
            isphase (bool): use a `jet` colour scale clipped to [-pi, pi].
        """
        self.gridplot = False
        self.xlabel = ax_labels[0]
        self.ylabel = ax_labels[1]
        self.fig.clf()
        self.ax = self.fig.add_subplot(111)
        # Start counting from t=0
        xaxis = xaxis - xaxis[0]
        yaxis = yaxis * 1e-6
        if self.polarizations is not None:
            if self.do_poldiff:
                pol = '{:s} - {:s}'.format(self.polarizations[0], self.polarizations[-1])
            else:
                pol = self.polarizations[self.polslot]
            if self.do_timediff and self.do_freqdiff:
                suffix = ' time and frequency diff'
            elif self.do_timediff:
                suffix = ' time diff'
            elif self.do_freqdiff:
                suffix = ' frequency diff'
            else:
                suffix = ''
            self.ax.set_title(frametitle + ':' + self.st.name + ' - {:s} - {:s}'.format(pol, self.direction) + suffix)
        else:
            self.ax.set_title(frametitle + ':' + self.st.name + ' - {:s}'.format(self.direction))
        vals = zaxis.copy()
        plot_x = xaxis.copy()
        plot_y = yaxis.copy()
        # Gap padding needs at least two samples to define a solution interval.
        if len(yaxis) > 1:
            solint_freq = np.diff(yaxis)[0]
            if not np.allclose(np.diff(yaxis), np.diff(yaxis)[0], rtol=0.01): # if not evenly spaced
                self.LOGGER.warning('Found irregular frequency axis, padding gaps with NaNs...')
                gapidx = np.nanargmax(np.abs(np.diff(plot_y)) > 2*solint_freq)  # index BEFORE the gap
                gapsize = plot_y[gapidx+1] - plot_y[gapidx]
                y_spacing = solint_freq
                while (gapsize >= 2*y_spacing) and (len(plot_y) <= yaxis[-1] / solint_freq):
                    vals = np.insert(vals, gapidx + 1, np.nan * np.ones((int(gapsize / y_spacing), vals.shape[0])), axis=1)  # fill gap with NaNs
                    plot_y = np.insert(plot_y, gapidx + 1, np.arange(plot_y[gapidx], plot_y[gapidx + 1], solint_freq))  # fill gap with NaNs
                    gapidx = np.nanargmax(np.abs(np.diff(plot_y)) >= 2*solint_freq)  # index BEFORE the gap
                    gapsize = np.abs(plot_y[gapidx+1] - plot_y[gapidx])

        if len(xaxis) > 1:
            solint_time = np.diff(xaxis)[0]
            if not np.allclose(np.diff(xaxis), np.diff(xaxis)[0], rtol=0.01): # if not evenly spaced
                self.LOGGER.warning('Found irregular time axis, padding gaps with NaNs...')
                gapidx = np.nanargmax(np.abs(np.diff(plot_x)) > 2*solint_time)  # index BEFORE the gap
                gapsize = plot_x[gapidx+1] - plot_x[gapidx]
                x_spacing = solint_time
                while (gapsize >= 2*x_spacing) and (len(plot_x) <= xaxis[-1] / solint_time): # if gap greater 2 times x-spacing, pad NaNs
                    vals = np.insert(vals, gapidx + 1, np.nan * np.ones((int(gapsize / x_spacing), vals.shape[1])), axis=0)  # fill gap with NaNs
                    plot_x = np.insert(plot_x, gapidx + 1, np.arange(plot_x[gapidx], plot_x[gapidx + 1], solint_time))  # fill gap with NaNs
                    gapidx = np.nanargmax(np.abs(np.diff(plot_x)) >= 2*solint_time)  # index BEFORE the gap
                    gapsize = np.abs(plot_x[gapidx+1] - plot_x[gapidx])
        # Use this window's own diff flags so the data agrees with the title and with plot_all.
        if self.do_timediff:
            vals = np.diff(vals, axis=0)
            if self.st.getType() == 'phase':
                vals = wrap_phase(vals)
        if self.do_freqdiff:
            vals = np.diff(vals, axis=1)
            if self.st.getType() == 'phase':
                vals = wrap_phase(vals)
        cmap = self._colormap('jet' if isphase else 'viridis')
        im = self.ax.imshow(vals.T, interpolation='none', extent=[plot_x[0], plot_x[-1], plot_y[0], plot_y[-1]], aspect='auto', cmap=cmap, origin='lower')
        if isphase:
            im.set_clim(-np.pi, np.pi)
        self.fig.colorbar(im)
        self.ax.set(xlabel=ax_labels[0], ylabel=ax_labels[1], xlim=limits[0], ylim=limits[1])
        self.canvas.draw()


class ListWidget(QListWidget):
    """ A QListWidget whose size hint follows the height of its visible items.

    Adapted from https://stackoverflow.com/questions/63497841/qlistwidget-does-not-resize-itself
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # No horizontal scrollbar; let the scroll area adapt its size to the contents.
        self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff)
        self.setSizeAdjustPolicy(QtWidgets.QAbstractScrollArea.AdjustToContents)

    def minimumSizeHint(self) -> QtCore.QSize:
        """ Report an invalid (-1, -1) minimum size hint so no minimum is imposed by the hint. """
        return QtCore.QSize(-1, -1)

    def viewportSizeHint(self) -> QtCore.QSize:
        """ Size the viewport to the summed row heights of all non-hidden items. """
        if not self.model().rowCount():
            return QtCore.QSize(self.width(), 0)
        shown_rows = (row for row in range(self.count()) if not self.item(row).isHidden())
        total_height = sum(map(self.sizeHintForRow, shown_rows))
        return QtCore.QSize(super().viewportSizeHint().width(), total_height)


class H5PlotGUI(QWidget):
    """The main GUI for H5Plot.

    From here the SolSets, SolTabs and antennas to plot are selected.
    """
    def __init__(self, h5file, logging_instance, parent=None):
        """ Initialize a new instance of the H5PlotGUI.

        Args:
            h5file (str): name of the H5Parm to open.
            logging_instance (logging.Logger): logger to write messages to.
            parent (QWidget or None): parent widget, if any.
        Returns:
            None
        """
        super(H5PlotGUI, self).__init__(parent)
        self.logger = logging_instance
        # Plot windows opened from this GUI; they are closed together with the main window.
        self.figures = []

        self.h5parm = lh5.h5parm(h5file)
        self.solset_labels = self.h5parm.getSolsetNames()
        self.solset = self.h5parm.getSolset(self.solset_labels[0])

        self.soltab_labels = self.solset.getSoltabNames()
        self.soltab = self.solset.getSoltab(self.soltab_labels[0])

        for label in self.soltab_labels:
            try:
                self.frequencies = self.solset.getSoltab(label).getAxisValues('freq')
                break
            except TypeError:
                pass
        for label in self.soltab_labels:
            try:
                self.times = self.solset.getSoltab(label).getAxisValues('time')
                break
            except TypeError:
                pass
        # Take the polarizations from the first SolTab that actually has a 'pol' axis.
        # The attribute is left undefined when no SolTab has one; plot_waterfall checks for that.
        for label in self.soltab_labels:
            try:
                candidate = self.solset.getSoltab(label)
                if 'pol' in candidate.getAxesNames():
                    self.polarizations = candidate.getAxisValues('pol')
                    break
            except TypeError:
                pass
        self.stations = self.soltab.getValues()[1]['ant']
        self.directions = self._read_directions(self.solset)
        # Index into self.directions of the direction to plot.
        self.direction = 0
        self.refant = self.stations[0]
        self.wrapphase = True

        # reorder_soltab reads values and weights itself, so build the cache from its output
        # instead of reading the soltab an extra time.
        rvals, rweights, raxes = reorder_soltab(self.soltab)
        self.stcache = SoltabCache(rvals, raxes, weights=rweights)

        self.move(300, 300)
        self.setWindowTitle('H5Plot')

        self.solset_label = QLabel('SolSet: ')
        self.solset_picker = QComboBox()
        self.solset_picker.addItems(self.solset_labels)
        self.solset_picker.activated.connect(self._solset_picker_event)

        self.soltab_label_y = QLabel('Plot ')
        self.soltab_label_x = QLabel(' vs ')
        self.soltab_picker = QComboBox()
        self.soltab_picker.addItems(self.soltab_labels)
        self.soltab_picker.activated.connect(self._soltab_picker_event)
        self.axis_picker = QComboBox()
        self.axis_picker.addItems(['time', 'freq', 'waterfall'])
        self.axis_picker.activated.connect(self._axis_picker_event)
        self.axis = 'time'

        self.refant_label = QLabel('Ref. Ant. ')
        self.refant_picker = QComboBox()
        self.refant_picker.addItems(self.stations)
        self.refant_picker.activated.connect(self._refant_picker_event)

        # NOTE: a 'Wrap Phases' checkbox (handled by _phasewrap_event) is not implemented yet.
        self.dir_label = QLabel('Dir.')
        self.dir_picker = QComboBox()
        self.dir_picker.addItems(self.directions)
        self.dir_picker.activated.connect(self._dir_picker_event)

        self.checkbox_layout = QGridLayout()
        self.check_weights = QCheckBox('Plot weights')
        self.check_tdiff = QCheckBox('Time diff.')
        self.check_fdiff = QCheckBox('Freq. diff.')
        self.check_pdiff = QCheckBox('Pol. diff. (XX-YY)')

        self.check_weights.toggled.connect(self._weight_picker_event)
        self.checkbox_layout.addWidget(self.check_weights, 0, 0)
        self.checkbox_layout.addWidget(self.check_tdiff, 0, 1)
        self.checkbox_layout.addWidget(self.check_fdiff, 1, 1)
        self.checkbox_layout.addWidget(self.check_pdiff, 1, 0)

        # Either 'values' or 'weights'; toggled by the 'Plot weights' checkbox.
        self.plotmode = 'values'

        self.plot_button = QPushButton('Plot')
        self.plot_button.clicked.connect(self._plot_button_event)

        self.plot_all_button = QPushButton('Plot all stations')
        self.plot_all_button.clicked.connect(self._plot_all_button_event)
        # Only available for waterfall plots, see _axis_picker_event.
        self.plot_all_button.setEnabled(False)

        self.station_picker = ListWidget()
        self.station_picker.addItems(self.stations)
        self.station_picker.setCurrentRow(0)

        plot_layout = QGridLayout()
        plot_layout.addWidget(self.soltab_label_y, 0, 0)
        plot_layout.addWidget(self.soltab_picker, 0, 1)
        plot_layout.addWidget(self.soltab_label_x, 0, 2)
        plot_layout.addWidget(self.axis_picker, 0, 3)
        plot_layout.addWidget(self.refant_label, 1, 0)
        plot_layout.addWidget(self.refant_picker, 1, 1)
        plot_layout.addWidget(self.dir_label, 1, 2)
        plot_layout.addWidget(self.dir_picker, 1, 3)

        layout = QFormLayout(self)
        layout.addRow(self.solset_label, self.solset_picker)
        layout.addRow(plot_layout)
        layout.addRow(self.checkbox_layout)
        layout.addRow(self.plot_button)
        layout.addRow(self.plot_all_button)
        layout.addRow(self.station_picker)

    @staticmethod
    def _read_directions(solset):
        """ Return the source (direction) names of `solset` as a list of str.

        Args:
            solset (SolSet): SolSet whose `getSou()` keys are used as direction names.
        Returns:
            directions (list): direction names, decoded from bytes when necessary.
        """
        try:
            return [s.decode('utf-8') for s in solset.getSou().keys()]
        except AttributeError:
            # Probably normal strings already.
            return list(solset.getSou().keys())

    def _refant_index(self):
        """ Return the index of the current reference antenna in `self.stations`.

        Raises:
            IndexError: if `self.refant` is not one of `self.stations`.
        """
        return int(np.flatnonzero(self.stations == self.refant)[0])

    def _axis_picker_event(self):
        """Callback function for when the x-axis is changed.

        Sets the `axis` attribute to the selected axis. 'Plot all stations' is only
        enabled for waterfall plots.
        """
        self.logger.debug('Axis changed to: ' + self.axis_picker.currentText())
        self.axis = self.axis_picker.currentText()
        self.plot_all_button.setEnabled(self.axis == 'waterfall')

    def closeEvent(self, event):
        """ The event triggered upon closing the main application window.

        Closes all plot windows and the H5Parm opened by this GUI, then accepts the event.
        """
        self.logger.info('Closing all open figures before exiting.')
        plt.close('all')
        for f in self.figures:
            f.close()
        # Release the file handle opened in __init__.
        self.h5parm.close()
        event.accept()

    def _refant_picker_event(self):
        """ An event triggered when a new reference antenna is selected.

        Sets the `refant` attribute.
        """
        self.logger.debug('Reference antenna changed to: ' + self.refant_picker.currentText())
        self.refant = self.refant_picker.currentText()

    def _solset_picker_event(self):
        """Callback function for when the SolSet is changed.

        Sets the `solset` attribute, refreshes the direction and SolTab pickers and loads
        the first SolTab of the new SolSet.
        """
        self.logger.debug('Solset changed to: ' + self.solset_picker.currentText())
        self.solset = self.h5parm.getSolset(self.solset_picker.currentText())
        # Directions belong to the SolSet; refresh them so labels and indices stay valid.
        directions = self._read_directions(self.solset)
        if directions != self.directions:
            self.directions = directions
            self.dir_picker.clear()
            self.dir_picker.addItems(self.directions)
            self.direction = 0
        self.soltab_labels = self.solset.getSoltabNames()
        self.soltab_picker.clear()
        self.soltab_picker.addItems(self.soltab_labels)
        self._soltab_picker_event()

    def _soltab_picker_event(self):
        """Callback function for when the SolTab is changed.

        Sets the `soltab` attribute, updates the station lists if they changed and
        refreshes the cached, reordered SolTab data.
        """
        self.logger.debug('Soltab changed to: ' + self.soltab_picker.currentText())
        self.soltab = self.solset.getSoltab(self.soltab_picker.currentText())
        stations_old = self.stations
        self.stations = self.soltab.getValues()[1]['ant']
        if not np.array_equiv(stations_old, self.stations):
            self.logger.debug('Number of stations changed, updating list.')
            # The list of stations has changed, update the list.
            self.station_picker.clear()
            self.station_picker.addItems(self.stations)
            # clear() drops the current row; without this currentRow() would return -1.
            self.station_picker.setCurrentRow(0)
            self.refant_picker.clear()
            self.refant_picker.addItems(self.stations)
            # Keep the reference antenna if it still exists, otherwise fall back to the first
            # station so the refant lookup in the plotting routines cannot fail.
            if self.refant in self.stations:
                self.refant_picker.setCurrentIndex(self._refant_index())
            else:
                self.refant = self.stations[0]
        try:
            self.frequencies = self.soltab.getAxisValues('freq')
        except TypeError:
            # Soltab probably has no frequency axis.
            pass
        rvals, rweights, raxes = reorder_soltab(self.soltab)
        self.stcache.update(rvals, raxes, weights=rweights)

    def _dir_picker_event(self):
        """Callback function for when the direction is changed.

        Sets the `direction` attribute to the index of the selected direction.
        """
        self.logger.debug('Direction changed to: ' + self.dir_picker.currentText())
        self.direction = self.dir_picker.currentIndex()

    def _phasewrap_event(self):
        """ An event triggered upon switching phase wrapping on or off (not yet implemented).

        NOTE(review): `self.phasewrap_box` is never created, so this handler is currently unused.
        """
        self.logger.debug('Phase wrapping changed to ' + str(self.phasewrap_box.isChecked()))
        self.wrapphase = self.phasewrap_box.isChecked()

    def _plot_button_event(self):
        """Callback function for when the plot button is pressed.

        Calls `plot` or `plot_waterfall` depending on the selected axis.
        """
        self.logger.debug('Plotting button pressed.')
        if self.axis in ('freq', 'time'):
            self.plot(labels=(self.axis, self.soltab.name), mode=self.plotmode)
        elif self.axis == 'waterfall':
            self.plot_waterfall(labels=('time', 'freq'), mode=self.plotmode)

    def _plot_all_button_event(self):
        """ Callback function for when the plot all stations button is pressed."""
        self.logger.debug('Plotting all stations button pressed.')
        if self.axis in ('freq', 'time'):
            self.plot(labels=(self.axis, self.soltab.name), mode=self.plotmode, plot_all=True)
        elif self.axis == 'waterfall':
            self.plot_waterfall(labels=('time', 'freq'), mode=self.plotmode, plot_all=True)

    def _weight_picker_event(self):
        """ Switch between plotting solution values and weights.

        Difference plots (time/freq/pol) are disabled while weights are shown.
        """
        show_weights = self.check_weights.isChecked()
        self.plotmode = 'weights' if show_weights else 'values'
        for box in (self.check_pdiff, self.check_fdiff, self.check_tdiff):
            box.setEnabled(not show_weights)
        self.logger.info('Plotting {:s}'.format(self.plotmode))

    def plot_waterfall(self, labels=('x-axis', 'y-axis'), mode='values', plot_all=False):
        """ Show a two-dimensional waterfall plot of time vs. frequency.

        Args:
            labels (tuple): axis labels (currently the plot uses fixed time/frequency labels).
            mode (str): 'values' to plot the solutions, 'weights' to plot their weights.
            plot_all (bool): plot all stations instead of the selected one.
        """
        name = self.soltab.name
        if 'phase_offset' in name:
            self.logger.info('Phase-offset is scalar and cannot be plotted in 2D.')
            return
        if any(key in name for key in ('rotationmeasure', 'RMextract', 'clock', 'faraday', 'tec')):
            self.logger.info('Rotation Measure, clock, faraday or TEC cannot be plotted in 2D!')
            return
        self.logger.info('Plotting ' + name + ' for ' + self.solset.name)
        antenna = self.station_picker.currentRow()
        refant_idx = self._refant_index()
        pols = getattr(self, 'polarizations', None)
        self.logger.debug('Loading data')

        def load(pol):
            # Data loaded here is xaxis, yaxis, zaxis, zweights, isphase.
            return load_axes_2d(self.stcache.values, self.stcache.weights, self.soltab, antenna=antenna,
                                refantenna=refant_idx, pol=pol, direction=self.direction)

        try:
            if pols is not None and len(pols) > 1 and self.check_pdiff.isChecked():
                # Plot the difference between the first and last polarization.
                x, y, z1, zw1, p = load(0)
                _, _, z2, zw2, _ = load(-1)
                z = z1 - z2
                # Combine flags of both polarisations.
                zw = zw1 * zw2
            else:
                x, y, z, zw, p = load(0)
        except ValueError:
            self.logger.error('Error loading 2D data!')
            return
        if (len(x) == 1) or (len(y) == 1):
            self.logger.info('Either time or frequency has only 1 entry, not plotting!')
            return
        plot_window = GraphWindow2D(self.stcache.values, self.stcache.weights, self.stations[antenna], antenna,
                                    refant_idx, self.axis, self.soltab, times=self.times,
                                    freqs=self.frequencies, pols=pols if pols is not None else ['N/A'],
                                    parent=self, direction=self.directions[self.direction], mode=mode,
                                    do_timediff=self.check_tdiff.isChecked(),
                                    do_freqdiff=self.check_fdiff.isChecked(),
                                    do_poldiff=self.check_pdiff.isChecked())
        self.figures.append(plot_window)
        if plot_all:
            plot_window.plot_all()
        elif mode == 'values':
            plot_window.plot(x, y, z, ax_labels=('Time [s]', 'Freq. [MHz]'), isphase=p, frametitle=self.stations[antenna])
        elif mode == 'weights':
            plot_window.plot(x, y, zw, ax_labels=('Time [s]', 'Freq. [MHz]'), frametitle=self.stations[antenna])

        plot_window.show()

    def plot(self, labels=('x-axis', 'y-axis'), limits=([None, None], [None, None]), mode='values', plot_all=False):
        """ Show a one-dimensional plot of the current SolTab against time or frequency.

        Args:
            labels (tuple): only `labels[1]` is used, as the y-axis label; the x-axis is
                labelled with the selected axis name.
            limits (tuple): currently unused; the plot window is created without limits.
            mode (str): 'values' or 'weights', forwarded to the plot window.
            plot_all (bool): currently unused for one-dimensional plots.
        """
        name = self.soltab.name
        if 'phase_offset' in name:
            self.logger.info('Phase-offset is scalar and cannot be plotted in 2D.')
        self.logger.info('Plotting ' + name + ' vs ' + self.axis + ' for ' + self.solset.name)
        antenna = self.station_picker.currentRow()
        no_freq_axis = any(key in name for key in ('rotationmeasure', 'RMextract', 'clock', 'faraday', 'tec'))
        if no_freq_axis and (self.axis == 'freq'):
            self.logger.info('Rotation Measure or clock does not support frequency axis! Switch to time instead.')
            return
        refant_idx = self._refant_index()
        msg = load_axes(self.stcache.values, self.soltab, self.axis, antenna=antenna, refantenna=refant_idx,
                        direction=self.direction, weights=self.stcache.weights)
        try:
            x_axis, y_axis, y_axis_weight, plabels, isphase = msg
        except ValueError:
            # Requested combination not supported (load_axes returned an error message).
            return
        window_kwargs = dict(times=self.times, parent=self, direction=self.directions[self.direction], mode=mode,
                             do_timediff=self.check_tdiff.isChecked(), do_freqdiff=self.check_fdiff.isChecked(),
                             do_poldiff=self.check_pdiff.isChecked())
        if 'freq' in self.soltab.getAxesNames():
            # SolTabs without a frequency axis (e.g. TEC) get no frequencies.
            window_kwargs['freqs'] = self.frequencies
        plot_window = GraphWindow(self.stcache.values, self.stcache.weights, self.stations[antenna], antenna,
                                  refant_idx, self.axis, self.soltab, **window_kwargs)
        self.figures.append(plot_window)
        plot_window.plot(x_axis, y_axis, y_axis_weight, self.stations[antenna], limits=[None, None],
                         ax_labels=[self.axis, labels[1]], plot_labels=plabels, isphase=isphase)
        plot_window.show()

        # TEC does not have a frequency axis, so disable the button as well.
        if 'tec' in name:
            self.logger.debug('TEC solutions detected, disabling buttons.')
            plot_window.button_next.setEnabled(False)
            plot_window.button_prev.setEnabled(False)
        axis = self.axis.lower()
        if axis == 'freq' and (len(self.times) == 1):
            plot_window.button_next.setEnabled(False)
            self.logger.debug('Single time slot detected, disabling buttons.')
        elif axis == 'freq' and (len(self.times) > 1):
            self.logger.debug('Multiple time slots detected, enabling buttons.')
            plot_window.button_next.setEnabled(True)
        if axis == 'time' and (self.frequencies is None or (len(self.frequencies) == 1)):
            plot_window.button_next.setEnabled(False)
            self.logger.debug('Single frequency slot detected, disabling buttons.')
        elif axis == 'time' and (len(self.frequencies) > 1):
            self.logger.debug('Multiple frequency slots detected, enabling buttons.')
            plot_window.button_next.setEnabled(True)


class SoltabCache:
    '''Small container for the reordered values, axes and weights of the active SolTab.'''
    def __init__(self, values, axes, weights=None):
        """ Create a cache filled with the given data.

        Args:
            values: solution values to keep (in this file, the values returned by reorder_soltab).
            axes (list): axis names describing `values`.
            weights (ndarray or None): weights belonging to `values`.
        Returns:
            None
        """
        self.update(values, axes, weights=weights)

    def update(self, nvalues, naxes, weights=None):
        """ Replace all cached data at once.

        Args:
            nvalues: new values to store in the cache.
            naxes (list): new axis names to store in the cache.
            weights (ndarray or None): new weights to store in the cache.
        Returns:
            None
        """
        self.values, self.axes, self.weights = nvalues, naxes, weights


# Global helper functions.
def reorder_soltab(st):
    """ Reorder a soltab in the order H5plot expects.

    The expected order in the plotter is time, frequency, antenna, polarization, direction.
    The frequency axis is omitted for rotation measure, RMextract, clock and faraday soltabs,
    and for TEC soltabs without a 'freq' axis. 'pol' and 'dir' are only included when the
    soltab has them.

    Side effect: `st.axes` and `st.axesNames` are overwritten to describe the new order.

    Args:
        st (SolTab): soltab instance to reorder the axes of.
    Returns:
        st_new (tuple): (reordered values, axes returned by `st.getValues()` after reordering).
        reordered_weights (ndarray): the weights, reordered like the values.
        order_new (list): the new order of the axis names.
    """
    # Look the logger up by name: the module-level LOGGER only exists when this file is run
    # as a script, so relying on it raises NameError when the module is imported.
    logger = logging.getLogger('H5plot_logger')
    logger.info('Reordering soltab ' + st.name)
    order_old = st.getAxesNames()
    if 'phase_offset' in st.name:
        # NOTE(review): despite this message the soltab is still reordered below.
        logger.info('Not reordering phase_offset as it is a single number.')
    no_freq_types = ('rotationmeasure', 'RMextract', 'clock', 'faraday')
    if any(t in st.name for t in no_freq_types) or ('tec' in st.name and 'freq' not in order_old):
        order_new = ['time', 'ant']
    else:
        order_new = ['time', 'freq', 'ant']
    order_new += [ax for ax in ('pol', 'dir') if ax in order_old]
    t1 = time.time()
    # Defensive copies: both calls must see the same, unmodified original axis order.
    reordered = reorderAxes(st.getValues()[0], list(order_old), order_new)
    reordered_weights = reorderAxes(st.getValues(weight=True)[0], list(order_old), order_new)
    t2 = time.time()
    logger.info('Reordering took {:f} seconds'.format(t2 - t1))
    st.axes = {k: st.axes[k] for k in order_new}
    st.axesNames = order_new
    st_new = (reordered, st.getValues()[1])
    return st_new, reordered_weights, order_new


def wrap_phase(phase):
    """ Wrap phase values into the interval [-pi, pi).

    Shifts by pi, takes the remainder modulo 2*pi and shifts back, i.e.
    ``(phase + pi) % (2 * pi) - pi``. Works element-wise on arrays as well as on scalars.

    Args:
        phase (ndarray or float): phase(s) in radians to wrap.
    Returns:
        ndarray or float: the wrapped phase(s).
    """
    period = 2 * np.pi
    shifted = phase + np.pi
    return shifted % period - np.pi


if __name__ == '__main__':
    import argparse
    print('H5Plot ' + __version__)
    print('Author: Frits Sweijen (frits.sweijen@gmail.com)')
    print('Released under GPLv3\n')
    parser = argparse.ArgumentParser(description='Interactive plotter to explore LOFAR H5parm solutions.')
    parser.add_argument('FILENAME')
    # Declared explicitly so argparse accepts the flag instead of exiting with
    # "unrecognized arguments".
    parser.add_argument('--debug', action='store_true', help='Log at DEBUG level instead of INFO.')
    args = parser.parse_args()

    # Set up logging to both h5plot.log and the terminal.
    LOGGER = logging.getLogger('H5plot_logger')
    LOGGER.setLevel(logging.DEBUG if args.debug else logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    LOGFILEH = logging.FileHandler('h5plot.log')
    LOGFILEH.setLevel(logging.DEBUG)
    LOGFILEH.setFormatter(formatter)
    LOGGER.addHandler(LOGFILEH)
    LOGSTREAM = logging.StreamHandler()
    LOGSTREAM.setLevel(logging.DEBUG)
    LOGSTREAM.setFormatter(formatter)
    LOGGER.addHandler(LOGSTREAM)

    H5FILE = lh5.h5parm(args.FILENAME, readonly=True)
    LOGGER.info('Successfully opened %s', args.FILENAME)
    try:
        SOLSETS = H5FILE.getSolsetNames()
        print('Found solset(s) ' + ', '.join(SOLSETS))
        for solset in SOLSETS:
            print('SolTabs in ' + solset + ':')
            soltabs = H5FILE.getSolset(solset).getSoltabNames()
            print('\t' + ', '.join(soltabs))

        # Initialize the GUI.
        APP = QApplication(sys.argv)
        GUI = H5PlotGUI(args.FILENAME, LOGGER)
        GUI.show()
        EXIT_CODE = APP.exec_()
    finally:
        # Close the file even if listing it or running the GUI raised.
        H5FILE.close()
    LOGGER.info('%s successfully closed.', args.FILENAME)
    # Propagate Qt's exit status to the shell.
    sys.exit(EXIT_CODE)
