#!python
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2015-)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

"""Search for evidence of beam scattering based on optic velocity
"""

from __future__ import division

import os.path
import re
import warnings

from six.moves import StringIO

import numpy

from matplotlib import (use, rcParams)
use('agg')  # noqa

import gwtrigfind

from gwpy.plot import Plot
from gwpy.plot.tex import label_to_latex
from gwpy.utils import gprint
from gwpy.table import EventTable
from gwpy.table.filters import in_segmentlist
from gwpy.segments import (DataQualityFlag, DataQualityDict,
                           Segment, SegmentList)

from gwdetchar import (cli, const, scattering)
from gwdetchar.io import html as htmlio
from gwdetchar.io.datafind import get_data
from gwdetchar.plot import get_gwpy_tex_settings

__author__ = 'Duncan Macleod <duncan.macleod@ligo.org>'

# -- plot configuration -------------------------------------------------------

# start from the TeX-related settings supplied by gwdetchar.plot
tex_settings = get_gwpy_tex_settings()
rcParams.update(tex_settings)

# figure styling specific to this script; applied second, so any key that
# also appears in the TeX settings takes the value given here
plot_style = {
    'axes.labelsize': 20,
    'figure.subplot.bottom': 0.17,
    'figure.subplot.left': 0.1,
    'figure.subplot.right': 0.9,
    'figure.subplot.top': 0.90,
    'grid.color': 'gray',
    'image.cmap': 'viridis',
    'svg.fonttype': 'none',
}
rcParams.update(plot_style)

# -- parse command line -------------------------------------------------------

# Build the argument parser.  The GPS start/stop, IFO, frametype and nproc
# options come from the shared helpers in gwdetchar.cli; everything else is
# specific to the scattering search.
parser = cli.create_parser(description=__doc__)
cli.add_gps_start_stop_arguments(parser)
cli.add_ifo_option(parser)
parser.add_argument('-a', '--state-flag', metavar='FLAG', default=None,
                    help='restrict search to times when FLAG was active')
parser.add_argument('-t', '--frequency-threshold', type=float, default=40.,
                    help='critical fringe frequency threshold (in Hertz), '
                         'default: %(default)s')
parser.add_argument('-x', '--multiplier-for-threshold', type=int,
                    default=4, choices=scattering.FREQUENCY_MULTIPLIERS,
                    help='fringe frequency multiplier to use when applying '
                         '--frequency-threshold, default: %(default)s')
parser.add_argument('-m', '--optic', action='append',
                    help='optic to search for scattering signal, can be given '
                         'multiple times, default: %s'
                         % ', '.join(map(str,
                                         scattering.OPTIC_MOTION_CHANNELS)))
parser.add_argument('-p', '--segment-padding', type=float, default=.05,
                    help='time with which to pad scattering segments on '
                         'either side, default: %(default)s')
parser.add_argument('-s', '--segment-start-pad', type=float, default=30.0,
                    help='amount of time to remove from the start of each '
                         'analysis segment, default: %(default)s')
parser.add_argument('-e', '--segment-end-pad', type=float, default=30.0,
                    help='amount of time to remove from the end of each '
                         'analysis segment, default: %(default)s')
parser.add_argument('-c', '--main-channel',
                    default='{IFO}:GDS-CALIB_STRAIN',
                    help='name of main (h(t)) channel, default: %(default)s')
cli.add_frametype_option(parser, default='%s_R' % const.IFO)
parser.add_argument('-o', '--output-dir', type=os.path.abspath,
                    default=os.curdir,
                    help='output directory for analysis, default: %(default)s')
parser.add_argument('-v', '--verbose', action='store_true', default=False,
                    help='print verbose output, default: %(default)s')
cli.add_nproc_option(parser)

args = parser.parse_args()

# use an integer threshold when it is a whole number, so that it is rendered
# as e.g. '40' rather than '40_0' in the file and flag names built from tstr
if args.frequency_threshold.is_integer():
    args.frequency_threshold = int(args.frequency_threshold)

# string tags used in output file names and flag names
tstr = str(args.frequency_threshold).replace('.', '_')
gpsstr = '%s-%s' % (int(args.gpsstart), int(args.gpsend-args.gpsstart))

# default to every optic known to gwdetchar.scattering (as a concrete list
# rather than a live dict view)
if args.optic is None:
    args.optic = list(scattering.OPTIC_MOTION_CHANNELS)

# all output is written relative to the output directory
if not os.path.isdir(args.output_dir):
    os.makedirs(args.output_dir)
os.chdir(args.output_dir)

segfile = '%s-SCATTERING_SEGMENTS_%s_HZ-%s.h5' % (args.ifo, tstr, gpsstr)

# -- get state segments -------------------------------------------------------

# full GPS span requested on the command line
span = Segment(args.gpsstart, args.gpsend)

# get segments
if args.state_flag is not None:
    state = DataQualityFlag.query(args.state_flag, int(args.gpsstart),
                                  int(args.gpsend),
                                  url=const.DEFAULT_SEGMENT_SERVER)
    # trim the requested amount of time from the start and end of each
    # active segment; segments too short to survive the trim are dropped
    # from the analysis entirely rather than being kept unpadded
    padding = args.segment_start_pad + args.segment_end_pad
    padded = SegmentList()
    for seg in state.active:
        if abs(seg) > padding:
            padded.append(type(seg)(
                seg[0]+args.segment_start_pad, seg[1]-args.segment_end_pad))
        elif args.verbose:
            gprint("Segment length {} shorter than padding length {}, "
                   "skipping this segment".format(abs(seg), padding))
    state.active = padded
    state.coalesce()
    statea = state.active
    livetime = float(abs(statea))
    if args.verbose:
        gprint("Downloaded %d segments for %s [%.2fs livetime]"
               % (len(statea), args.state_flag, livetime))
else:
    # no state flag: analyse the whole span, which is also the livetime
    statea = SegmentList([span])
    livetime = float(abs(statea))

# -- load h(t) ----------------------------------------------------------------

args.main_channel = args.main_channel.format(IFO=args.ifo)
if args.verbose:
    gprint("Loading Omicron triggers for %s..." % args.main_channel, end=' ')

# Choose the trigger file format from the start time: HDF5 from GPS
# 1230336018 onwards, LIGO_LW XML before that.  In both cases only three
# columns are read (time, frequency, SNR, in that order in `names`), and
# rows are kept only if they lie below twice the frequency threshold and
# inside the analysis segments.
if args.gpsstart >= 1230336018:  # Jan 1 2019
    ext = "h5"
    names = ["time", "frequency", "snr"]
    read_kw = {
        "columns": names,
        "selection": [
            "frequency < {}".format(args.frequency_threshold * 2),
            ("time", in_segmentlist, statea),
        ],
        "format": "hdf5",
        "path": "triggers",
    }
else:
    try:
        from glue.ligolw import lsctables  # noqa: F401
    except ImportError as exc:
        # keep the original message and append a hint about the missing
        # dependency, then re-raise the same exception
        exc.args = (
            "{!s}, `lscsoft-glue` is required to read LIGO_LW XML "
            "files".format(exc),
        )
        raise
    ext = "xml.gz"
    names = ['peak', 'peak_frequency', 'snr']
    read_kw = {
        "columns": names,
        "selection": [
            "peak_frequency < {}".format(args.frequency_threshold * 2),
            ('peak', in_segmentlist, statea),
        ],
        "format": 'ligolw',
        "tablename": "sngl_burst",
    }

# find the trigger files for each analysis segment, using the extension
# that matches the reader configuration chosen above
fullcache = []
for seg in statea:
    cache = gwtrigfind.find_trigger_files(
        args.main_channel, 'omicron', seg[0], seg[1], ext=ext,
    )
    if not cache:
        warnings.warn("No Omicron triggers found for %s in segment [%d .. %d)"
                      % (args.main_channel, seg[0], seg[1]))
        continue
    fullcache.extend(cache)

# read triggers
if fullcache:
    trigs = EventTable.read(fullcache, nproc=args.nproc, **read_kw)
else:  # no files (no livetime?)
    trigs = EventTable(names=names)

# the louder triggers are the ones used for the efficiency statistics
highsnrtrigs = trigs[trigs['snr'] >= 8]
if args.verbose:
    gprint("%d read" % len(trigs))

# -- prepare HTML -------------------------------------------------------------

# integer GPS bounds, as shown in the page heading and arguments summary
gps_start = int(args.gpsstart)
gps_end = int(args.gpsend)

page = htmlio.new_bootstrap_page(title='%s Scattering' % args.ifo)
page.div(class_='container')

# page header: title plus a short description of the method
page.div(class_='page-header')
page.h1('%s scattering: %d-%d' % (args.ifo, gps_start, gps_end))
page.p("This analysis searched for evidence of beam scattering based on the "
       "velocity of optic motion. The fringe frequency is predicted using "
       "equation (3) of <a href=\"http://iopscience.iop.org/article/10.1088/"
       "0264-9381/27/19/194011\" target=\"_blank\">Accadia et al. (2010)</a>.")
page.div.close()

# HTML block generated by gwdetchar from the GPS bounds and state flag
arguments_html = htmlio.write_arguments(
    [], start=gps_start, end=gps_end, flag=args.state_flag)
page.add(arguments_html)

# link to the output segments file
page.h2('Segments')
page.p()
page.add('The full output segments are recorded in HDF5 format here:')
page.a(os.path.basename(segfile), href=segfile, target='_blank')
page.p.close()

# when a state flag was requested, show the segments that were analysed
if args.state_flag is not None:
    known_style = {'facecolor': 'red', 'edgecolor': 'darkred', 'height': 0.4}
    page.p('This analysis was executed over the following segments:')
    page.div(class_='panel-group', id_='accordion1')
    state_html = htmlio.write_flag_html(
        state, span, 'state', parent='accordion1', context='success',
        plotdir='', facecolor=(0.2, 0.8, 0.2), edgecolor='darkgreen',
        known=known_style)
    page.add(state_html)
    page.div.close()

page.p("The following channels were searched for evidence of scattering "
       "(yellow = weak evidence, red = strong evidence):")
# the channel panels get their own accordion id, distinct from the
# 'accordion1' group used for the state segments
page.div(class_='panel-group', id_='accordion2')

# -- find scattering evidence -------------------------------------------------

# For every optic motion channel: read the data for each analysis segment,
# compute the predicted fringe frequency, flag times where the multiplied
# fringe frequency reaches the threshold, compare those times against the
# Omicron triggers, and write two plots plus an HTML panel per channel.

allchannels = ['%s:%s' % (args.ifo, c) for optic in args.optic for
               c in scattering.OPTIC_MOTION_CHANNELS[optic]]

# total analysed time, computed here so that the deadtime statistics and the
# histogram limits work whether or not a state flag was given
livetime = float(abs(statea))

if args.verbose:
    print("Reading all data:")
alldata = []
n = len(statea)
for i, seg in enumerate(statea, 1):  # count from 1 for the progress message
    msg = "    {0}/{1} {2}: ".format(
        str(i).rjust(len(str(n))),
        n,
        str(seg),
    ) if args.verbose else False
    alldata.append(
        get_data(allchannels, seg[0], seg[1], frametype=args.frametype,
                 obs=args.ifo[0], verbose=msg, nproc=args.nproc).resample(128))

scatter_segments = DataQualityDict()

# with no analysis segments there is no data, and so no channels to process
channels = sorted(alldata[0].keys()) if alldata else []

for i, channel in enumerate(channels):
    if args.verbose:
        gprint("-- Processing %s --------------------" % channel)
    chanstr = re.sub('[:-]', '_', channel).replace('_', '-', 1)
    optic = channel.split('-')[1].split('_')[0]
    flag = '%s:DCH-%s_SCATTERING_GE_%s_HZ:1' % (args.ifo, optic, tstr)
    scatter_segments[channel] = DataQualityFlag(
        flag,
        isgood=False,
        description="Evidence for scattering above {0} Hz from {1} in "
                    "{2}".format(args.frequency_threshold, optic, channel),
    )
    # set up plot(s): four stacked panels sharing one time axis
    plot = Plot(figsize=[12, 12])
    axes = {}
    axes['position'] = plot.add_subplot(411, xscale='auto-gps', xlabel='')
    axes['fringef'] = plot.add_subplot(412, sharex=axes['position'], xlabel='')
    axes['triggers'] = plot.add_subplot(413, sharex=axes['position'],
                                        xlabel='')
    axes['segments'] = plot.add_subplot(414, projection='segments',
                                        sharex=axes['position'])
    plot.subplots_adjust(bottom=.07, top=.95)
    # fringe frequency samples accumulated over all segments, per multiplier
    histdata = {m: numpy.empty((0,))
                for m in scattering.FREQUENCY_MULTIPLIERS}
    # colours are taken from the first line drawn and reused afterwards so
    # that every segment is drawn in the same colour
    linecolor = None
    fringecolors = [None] * len(scattering.FREQUENCY_MULTIPLIERS)
    # loop over state segments and find scattering fringes
    for j, seg in enumerate(statea):
        if args.verbose:
            gprint("Processing segment [%d .. %d)..." % seg)
        ts = alldata[j][channel]
        # get raw data and plot
        line = axes['position'].plot(ts, color=linecolor)[0]
        linecolor = line.get_color()
        # get fringe frequency and plot
        fringef = scattering.get_fringe_frequency(ts, multiplier=1)
        for k, m in list(enumerate(scattering.FREQUENCY_MULTIPLIERS))[::-1]:
            fm = fringef * m
            # only label the lines of the first segment, to get one legend
            # entry per multiplier
            line = axes['fringef'].plot(
                fm, color=fringecolors[k],
                label=(r'$f\times%d$' % m if j == 0 else None))[0]
            fringecolors[k] = line.get_color()
            histdata[m] = numpy.concatenate((histdata[m], fm.value))
        # get segments and plot
        scaled = fringef * args.multiplier_for_threshold
        if scaled.value.max() < args.frequency_threshold:
            # nothing reaches the threshold: record the span as known only
            scatter = DataQualityFlag(flag, known=[ts.span])
        else:
            scatter = (
                scaled >=
                args.frequency_threshold * fringef.unit).to_dqflag(name=flag)
            scatter.active = scatter.active.protract(args.segment_padding)
            scatter.coalesce()
        axes['segments'].plot(scatter, facecolor='red', edgecolor='darkred',
                              known={'alpha': 0.6, 'facecolor': 'lightgray',
                                     'edgecolor': 'gray', 'height': 0.4},
                              height=0.8, y=0, label=' ')
        scatter_segments[channel] += scatter
        if args.verbose:
            gprint("    Found %d scattering segments" % (len(scatter.active)))
    if args.verbose:
        gprint("Completed channel, found %d segments in total."
               % len(scatter_segments[channel].active))

    # calculate efficiency and deadtime of veto
    deadtime = abs(scatter_segments[channel].active)
    deadtimepc = deadtime / livetime * 100 if livetime else 0.
    if args.verbose:
        gprint("Deadtime: %.2f%% (%.2f/%ds)"
               % (deadtimepc, deadtime, livetime))
    nhighsnr = len(highsnrtrigs)
    efficiency = in_segmentlist(highsnrtrigs[names[0]],
                                scatter_segments[channel].active).sum()
    # `efficiency` is a numpy integer, which does not raise
    # ZeroDivisionError on division by zero, so test the denominator directly
    efficiencypc = efficiency / nhighsnr * 100 if nhighsnr else 0.
    if args.verbose:
        gprint("Efficiency (SNR>=8): %.2f%% (%d/%d)"
               % (efficiencypc, efficiency, nhighsnr))
    if deadtimepc == 0.:
        effdt = 0
    else:
        effdt = efficiencypc/deadtimepc
    if args.verbose:
        gprint("Efficiency/Deadtime: %.2f" % effdt)

    # finalize plot
    if args.verbose:
        gprint("Plotting...", end=' ')
    usetex = rcParams["text.usetex"]
    name = label_to_latex(channel) if usetex else channel
    axes['position'].set_title("Scattering evidence in %s" % name)
    axes['position'].set_xlabel('')
    # \textmu is only available when rendering text with TeX
    axes['position'].set_ylabel(
        r'Position [\textmu m]' if usetex else r'Position [$\mu$m]')
    axes['position'].text(
        0.01, 0.95, 'Optic position',
        transform=axes['position'].transAxes, va='top', ha='left',
        bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': .5})
    axes['fringef'].plot(
        span, [args.frequency_threshold, args.frequency_threshold], 'k--')
    axes['fringef'].set_xlabel('')
    axes['fringef'].set_ylabel(r'Frequency [Hz]')
    axes['fringef'].yaxis.tick_right()
    axes['fringef'].yaxis.set_label_position("right")
    axes['fringef'].set_ylim(0, 2 * args.frequency_threshold)
    axes['fringef'].text(
        0.01, 0.95, 'Calculated fringe frequency',
        transform=axes['fringef'].transAxes, va='top', ha='left',
        bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': .5})
    handles, labels = axes['fringef'].get_legend_handles_labels()
    axes['fringef'].legend(handles[::-1], labels[::-1], loc='upper right',
                           borderaxespad=0, bbox_to_anchor=(-0.01, 1.),
                           handlelength=1)

    # trigger panel: time vs frequency, coloured by SNR (see `names`)
    axes['triggers'].scatter(
        trigs[names[0]],
        trigs[names[1]],
        c=trigs[names[2]],
        edgecolor='none',
    )
    name = args.main_channel
    if usetex:
        name = label_to_latex(name)
    axes['triggers'].text(
        0.01, 0.95,
        '%s event triggers (Omicron)' % name,
        transform=axes['triggers'].transAxes, va='top', ha='left',
        bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': .5})
    axes['triggers'].set_ylabel('Frequency [Hz]')
    axes['triggers'].set_ylim(0, 2 * args.frequency_threshold)
    axes['triggers'].colorbar(cmap='YlGnBu', clim=(3, 100), norm='log',
                              label='Signal-to-noise ratio')
    axes['segments'].set_ylim(-.55, .55)
    # label the panel with the multiplier and comparison actually applied
    axes['segments'].text(
        0.01, 0.95,
        r'Time segments with $f\times%d \geq %.2f$ Hz'
        % (args.multiplier_for_threshold, args.frequency_threshold),
        transform=axes['segments'].transAxes, va='top', ha='left',
        bbox={'edgecolor': 'none', 'facecolor': 'white', 'alpha': .5})
    for ax in axes.values():
        ax.set_epoch(int(args.gpsstart))
        ax.set_xlim(*span)
    png = '%s_SCATTERING_%s_HZ-%s.png' % (chanstr, tstr, gpsstr)
    try:
        plot.save(png)
    except OverflowError as e:
        # retry once with different y-limits on the second panel
        warnings.warn(str(e))
        plot.axes[1].set_ylim(0, args.frequency_threshold * 4)
        plot.refresh()
        plot.save(png)
    plot.close()
    if args.verbose:
        gprint("%s written." % png)

    # make histogram
    histogram = Plot(figsize=[12, 6])
    ax = histogram.gca()
    hrange = (0, 2 * args.frequency_threshold)
    # pair each multiplier with its colour explicitly (fringecolors is
    # indexed in FREQUENCY_MULTIPLIERS order) instead of relying on dict
    # iteration order
    for m, color in list(zip(scattering.FREQUENCY_MULTIPLIERS,
                             fringecolors))[::-1]:
        if histdata[m].size:
            # `ts` is the series from the last segment processed above; its
            # sample spacing is used as the weight of every sample
            ax.hist(
                histdata[m], facecolor=color, alpha=.6, range=hrange,
                bins=50, histtype='stepfilled', label=r'$f\times%d$' % m,
                cumulative=-1, weights=ts.dx.value, bottom=1e-100, log=True)
        else:
            ax.plot(histdata[m], color=color, label=r'$f\times%d$' % m)
            ax.set_yscale('log')
    ax.set_ylim(.01, float(livetime))
    ax.set_ylabel('Time with fringe above frequency [s]')
    ax.set_xlim(*hrange)
    ax.set_xlabel('Frequency [Hz]')
    ax.set_title(axes['position'].get_title())
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[::-1], labels[::-1], loc='upper right')
    hpng = '%s_SCATTERING_HISTOGRAM-%s.png' % (chanstr, gpsstr)
    histogram.save(hpng)
    histogram.close()
    if args.verbose:
        gprint("%s written." % hpng)

    # write HTML
    # panel colour: 'danger' (red) for efficiency/deadtime above 2,
    # 'warning' (yellow) for any other scattering segments or a fringe
    # frequency reaching half the threshold, 'default' otherwise
    if deadtime != 0 and effdt > 2:
        context = 'danger'
    elif ((deadtime != 0 and effdt < 2) or
          (histdata[args.multiplier_for_threshold].size and
           histdata[args.multiplier_for_threshold].max() >=
              args.frequency_threshold/2.)):
        context = 'warning'
    else:
        context = 'default'
    page.div(class_='panel panel-%s' % context)
    page.div(class_='panel-heading')
    # data-parent must match the id of the enclosing panel-group
    page.a(channel, class_="panel-title", href='#flag%s' % i,
           **{'data-toggle': 'collapse', 'data-parent': '#accordion2'})
    page.div.close()
    page.div(id_='flag%s' % i, class_='panel-collapse collapse')
    img = htmlio.FancyPlot(png)
    page.add(htmlio.fancybox_img(img))
    himg = htmlio.FancyPlot(hpng)
    page.add(htmlio.fancybox_img(himg))
    page.div(class_='panel-body')
    if deadtime:
        page.p("%d segments were found predicting a scattering fringe above "
               "%.2f Hz." % (len(scatter_segments[channel].active),
                             args.frequency_threshold))
        page.table(class_='table table-condensed table-hover')
        page.tbody()
        page.tr()
        page.th('Deadtime')
        page.td('%.2f/%d seconds' % (deadtime, livetime))
        page.td('%.2f%%' % deadtimepc)
        page.tr.close()
        page.tr()
        page.th('Efficiency<br><small>(SNR&ge;8 and '
                'f<sub>peak</sub>&lt;%.2f Hz)</small>'
                % (2 * args.frequency_threshold))
        page.td('%d/%d events' % (efficiency, nhighsnr))
        page.td('%.2f%%' % efficiencypc)
        page.tr.close()
        page.tr()
        page.th('Efficiency/Deadtime')
        page.td()
        page.td('%.2f' % effdt)
        page.tr.close()
        page.tbody.close()
        page.table.close()
        # embed the segment list itself in the page
        segs = StringIO()
        scatter_segments[channel].active.write(segs, format='segwizard',
                                               coltype=float)
        page.pre(segs.getvalue())
    else:
        page.p("No segments were found with scattering above %.2f Hz."
               % args.frequency_threshold)
    page.div.close()  # panel-body
    page.div.close()  # panel-collapse
    page.div.close()  # panel

# close the channel accordion opened before the loop
page.div.close()

# -- finalize -----------------------------------------------------------------

# save the scattering segments for all channels to the HDF5 segments file
scatter_segments.write(segfile, path="segments", overwrite=True)
if args.verbose:
    gprint("%s written" % segfile)

# append the footer, close the outstanding <div> and save the page in the
# current working directory (the output directory)
footer = htmlio.write_footer()
page.add(footer)
page.div.close()
with open('index.html', 'w') as htmlfile:
    htmlfile.write(str(page))
if args.verbose:
    gprint("-- index.html written, all done --")
