#!/usr/bin/env python

# Copyright (C) 2018 Patrick Godwin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

### script metadata; __usage__ and __description__ are also handed to the
### OptionParser in the main section to build the --help output
__usage__ = "dtt2tfplot [--options] input_file_1 ... input_file_N"
__description__ = "an executable to generate transfer function plots from dtt XML files"
__author__ = "Patrick Godwin (patrick.godwin@ligo.org)"
### module docstring: usage, description and author separated by blank lines
__doc__ = "\n\n".join([__usage__, __description__, __author__])

#---------------------------
### imports

import os
from optparse import OptionParser

import numpy

import dtt2hdf

import matplotlib
matplotlib.use('Agg')
from matplotlib import pyplot as plt

### global plot styling, applied once at import time

### font sizes (points)
matplotlib.rcParams["font.size"] = 10.0
matplotlib.rcParams["axes.titlesize"] = 10.0
matplotlib.rcParams["axes.labelsize"] = 10.0
matplotlib.rcParams["xtick.labelsize"] = 8.0
matplotlib.rcParams["ytick.labelsize"] = 8.0
matplotlib.rcParams["legend.fontsize"] = 8.0

### resolution used for figures and for files written by savefig
matplotlib.rcParams["figure.dpi"] = 300
matplotlib.rcParams["savefig.dpi"] = 300

### rendering: no LaTeX text rendering, simplified line paths
matplotlib.rcParams["text.usetex"] = False
matplotlib.rcParams["path.simplify"] = True

#---------------------------
### functions

def plot_dtt(freq, transfer_fn, ch_name, grid):
    """
    Create a two-panel figure showing the amplitude and phase of a transfer function.

    The upper panel shows the magnitude of transfer_fn against freq on log-log
    axes; the lower panel shows its phase in degrees against freq on a
    logarithmic frequency axis. Both panels share the same x axis.

    Parameters
    ----------
    freq : array-like
        Values plotted along the x axis (labelled "Frequency [Hz]").
    transfer_fn : array-like
        Transfer function samples, one per entry of freq. Presumably complex
        valued; the magnitude and angle are taken with numpy.
    ch_name : str
        Channel name, used only in the figure title.
    grid : str
        Which grid lines to draw; forwarded to matplotlib as the ``which``
        argument of ``Axes.grid`` ("major", "minor" or "both").

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure. The caller is responsible for saving and
        closing it.
    """

    ### extract amplitude and phase
    transfer_fn_amp = numpy.absolute(transfer_fn)
    transfer_fn_phase = numpy.angle(transfer_fn, deg=True)

    fig, axes = plt.subplots(nrows=2, sharex=True)

    ### amplitude plot
    axes[0].loglog(freq, transfer_fn_amp, color = '#fc8d59', alpha = 0.9, ls = 'solid')

    axes[0].set_ylabel(r"Amplitude")
    axes[0].set_title(r"Transfer Function for Channel %s"%ch_name)
    ### pass True explicitly: grid() called without a visibility flag *toggles*
    ### the grid, which would switch it off if the user's rc settings enable
    ### grids by default
    axes[0].grid(True, which=grid)

    ### phase plot
    axes[1].semilogx(freq, transfer_fn_phase, color = '#fc8d59', alpha = 0.9, ls = 'solid')

    axes[1].set_xlabel(r"Frequency [Hz]")
    axes[1].set_ylabel(r"Phase [deg]")
    ### numpy.angle returns values in [-180, 180]; leave a small margin and
    ### label the quarter turns
    axes[1].set_ylim((-200, 200))
    axes[1].set_yticks([-180, -90, 0, 90, 180])
    axes[1].grid(True, which=grid)

    fig.tight_layout(pad = .8)

    return fig

#---------------------------
### main

def _first_channel_matching(ch_names, tag):
    """
    Return the first entry of ch_names that contains the substring tag.

    Raises a ValueError naming the missing tag when nothing matches, instead
    of the opaque IndexError an empty-list lookup would give.
    """
    for ch_name in ch_names:
        if tag in ch_name:
            return ch_name
    raise ValueError("no channel containing '%s' found among %s" % (tag, list(ch_names)))


if __name__ == "__main__":

    ### define CLI options
    parser = OptionParser(usage=__usage__, description=__description__)

    parser.add_option('-v', '--verbose', default=False, action='store_true',
        help='Be verbose.')
    parser.add_option('-o', '--output-dir', metavar = "directory", default=".",
        help='Set the output directory for plots.')
    parser.add_option('--grid', default="major",
        help='Sets the grid lines for plots, options are [major|minor|both]. Default="major".')

    options, args = parser.parse_args()

    ### validate the command line before touching the filesystem;
    ### parser.error prints the usage and exits, and unlike assert it is
    ### not stripped when python runs with -O
    if not args:
        parser.error('please provide at least one input file')

    grid = options.grid.lower()
    if grid not in ('major', 'minor', 'both'):
        parser.error("--grid must be one of [major|minor|both], got '%s'" % options.grid)

    ### create directory if it doesn't already exist
    try:
        os.makedirs(options.output_dir)
    except OSError:
        ### an already-existing directory is fine, anything else
        ### (permissions, a file in the way, ...) is a real failure
        if not os.path.isdir(options.output_dir):
            raise

    if options.verbose:
        print("parsing files...")

    ### read in dtt XML files, keyed by the input path without its extension
    dtt_files = [(os.path.splitext(file_)[0], dtt2hdf.read_diaggui(file_)) for file_ in args]

    if options.verbose:
        print("creating plots...")

    for dtt_name, dtt_file in dtt_files:

        if options.verbose:
            print("    creating plot for channel %s ..."%dtt_name)

        ### parse dtt file; each lookup takes the first match and assumes it is unique
        coh_ch_names = list(dtt_file['results']['COH'].keys())
        exc_ch_name = _first_channel_matching(coh_ch_names, 'EXC')
        in1_ch_name = _first_channel_matching(coh_ch_names, 'IN1')
        in2_ch_name = _first_channel_matching(coh_ch_names, 'IN2')
        channel = exc_ch_name.rsplit('_', 1)[0]

        dtt_data = dtt_file['results']['CSD']
        ### materialize the keys: a python 3 dict view cannot be indexed below
        ch_names = list(dtt_data.keys())

        ### extract all CSDs as a dict in the form csds[channelA][channelB] == S_chA_chB
        ### 'channelB_inv' supplies, for each channelB name, the index into channelA's 'CSD' entry
        csds = {}
        for chA_name in ch_names:
            chA_data = dtt_data[chA_name]
            chB_idxs = dict(chA_data['channelB_inv'])
            csds[chA_name] = {chB_name: chA_data['CSD'][idx] for chB_name, idx in chB_idxs.items()}

        ### extract frequency array: start frequency plus one step of df per sample
        ### (f0 and df are read from the first channel only)
        f0 = dtt_data[ch_names[0]]['f0']
        df = dtt_data[ch_names[0]]['df']
        freq = f0 + numpy.arange(len(csds[exc_ch_name][in1_ch_name])) * df

        ### format dtt data for plotting
        in2_in1_csd = csds[in2_ch_name][in1_ch_name].astype(numpy.complex128)
        exc_in1_csd = csds[exc_ch_name][in1_ch_name].astype(numpy.complex128)
        transfer_fn_2e = in2_in1_csd / exc_in1_csd
        transfer_fn_open_loop = numpy.reciprocal(transfer_fn_2e) - 1

        ### generate and save plots
        fname = 'plot_%s.png'%channel
        dtt_fig = plot_dtt(freq, transfer_fn_open_loop, channel, grid)
        dtt_fig.savefig(os.path.join(options.output_dir, fname))
        ### close the figure object itself so figures do not pile up in
        ### memory across input files
        plt.close(dtt_fig)

    if options.verbose:
        print("...done.")
