#!python
# coding=utf-8
# Copyright (C) LIGO Scientific Collaboration (2015-)
#
# This file is part of the GW DetChar python package.
#
# GW DetChar is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# GW DetChar is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with GW DetChar.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import (division, print_function)

import os
import re
import multiprocessing
import sys

import numpy

from matplotlib import (use, rcParams)
use('agg')  # nopep8
from matplotlib import cm as mcm

from sklearn import linear_model
from sklearn.preprocessing import scale

from gwpy.table import Table
from pandas import set_option

from gwpy.plot import Plot
from gwpy.time import Time
from gwpy.detector import ChannelList
from gwpy.io import nds2 as io_nds2

from gwdetchar import (cli, lasso as gwlasso)
from gwdetchar.lasso import plot as gwplot
from gwdetchar.io import html as htmlio
from gwdetchar.io.datafind import get_data
from gwdetchar.plot import get_gwpy_tex_settings


# default frametypes, keyed by the primary channel name with its ``<ifo>:``
# prefix removed; the ``{ifo}`` placeholder is filled in from the --ifo
# option when --primary-frametype is not given
DEFAULT_FRAMETYPE = {
    'GDS-CALIB_STRAIN': '{ifo}_HOFT_C00',
    'DMT-SNSH_EFFECTIVE_RANGE_MPC.mean': 'SenseMonitor_hoft_{ifo}_M',
    'DMT-SNSW_EFFECTIVE_RANGE_MPC.mean': 'SenseMonitor_CAL_{ifo}_M',
}


# -- parse command line -------------------------------------------------------

# NOTE(review): no module docstring is visible in this script, so
# ``description=__doc__`` currently resolves to None and --help shows no
# description -- consider adding a docstring at the top of the file.
parser = cli.create_parser(
    description=__doc__,
    formatter_class=cli.argparse.ArgumentDefaultsHelpFormatter)
cli.add_gps_start_stop_arguments(parser)
cli.add_ifo_option(parser)
cli.add_nproc_option(parser, default=1)

# general options
parser.add_argument('-J', '--nproc-plot', type=int, default=None,
                    help='number of processes to use for plotting')
parser.add_argument('-o', '--output-dir', default=os.curdir,
                    help='output directory for plots')
# the channel file is made absolute because the script later changes
# directory into --output-dir before opening it
parser.add_argument('-f', '--channel-file', type=os.path.abspath,
                    help='path for channel file')
parser.add_argument('-T', '--trend-type', default='minute',
                    choices=['second', 'minute'],
                    help='type of trend for correlation')
parser.add_argument('-p', '--primary-channel',
                    default='{ifo}:DMT-SNSH_EFFECTIVE_RANGE_MPC.mean',
                    help='name of primary channel to use')
parser.add_argument('-P', '--primary-frametype', default=None,
                    help='frametype for --primary-channel, default: guess '
                         'by channel name')
parser.add_argument('-O', '--remove-outliers', type=float, default=None,
                    help='Std. dev. limit for removing outliers')
parser.add_argument('-t', '--threshold', type=float, default=0.0001,
                    help='threshold for making a plot')

# options controlling the optional band-limited RMS of the primary channel
psig = parser.add_argument_group('Signal processing options')
psig.add_argument('-b', '--band-pass', type=float, nargs=2, default=None,
                  metavar="FLOW FHIGH",
                  help='lower and upper frequencies for bandpass on h(t)')
psig.add_argument('-x', '--filter-padding', type=float, default=3.,
                  help='amount of time (seconds) to pad data for filtering')

# options controlling the lasso fit, clustering and line widths
lsig = parser.add_argument_group('Lasso options')
lsig.add_argument('-a', '--alpha', default=None, type=float,
                  help='alpha parameter for lasso fit')
lsig.add_argument('-C', '--no-cluster', action='store_true', default=False,
                  help='do not generate clustered channel plots')
lsig.add_argument('-c', '--cluster-coefficient', default=.85, type=float,
                  help='correlation coefficient threshold for clustering')
lsig.add_argument('-L', '--line-size-primary', default=1, type=float,
                  help='line width of primary channel')
lsig.add_argument('-l', '--line-size-aux', default=0.75, type=float,
                  help='line width of auxiliary channel')

args = parser.parse_args()

# set TeX settings
tex_settings = get_gwpy_tex_settings()
rcParams.update(tex_settings)

# integer GPS bounds of the analysis, and padding (seconds) for filtering
start = int(args.gpsstart)
end = int(args.gpsend)
pad = args.filter_padding
# NOTE(review): auto_xlabel is not referenced anywhere else in this script
auto_xlabel = ('Time [hours] from '
               + re.sub(r'\.0+', '',
                        Time(start, format='gps', scale='utc').iso)
               + ' UTC (%d)' % start)

# get primary channel frametype
primary = args.primary_channel.format(ifo=args.ifo)
range_is_primary = 'EFFECTIVE_RANGE_MPC' in args.primary_channel
if args.primary_frametype is None:
    try:
        args.primary_frametype = DEFAULT_FRAMETYPE[
            args.primary_channel.split(':')[1]].format(ifo=args.ifo)
    except (KeyError, IndexError) as exc:
        # KeyError: the channel has no known default frametype
        # IndexError: the channel name has no '<ifo>:' prefix to strip
        raise type(exc)("Could not determine primary channel's frametype, "
                        "please specify with --primary-frametype")

# all output files are written relative to the output directory
if not os.path.isdir(args.output_dir):
    os.makedirs(args.output_dir)
os.chdir(args.output_dir)
# fall back to --nproc when --nproc-plot is not given
nprocplot = args.nproc_plot or args.nproc

if args.band_pass:
    # --band-pass is declared with nargs=2, so a truthy value always unpacks
    # into exactly two frequencies
    flower, fupper = args.band_pass

    print("-- Loading primary channel data")
    # read ``pad`` extra seconds on each side; they are cropped off again
    # once the data have been filtered
    bandts = get_data(
        primary, start-pad, end+pad, verbose='Reading primary:',
        obs=args.ifo[0], frametype=args.primary_frametype, nproc=args.nproc)
    if flower < 0 or fupper >= float((bandts.sample_rate/2.).value):
        raise ValueError("bandpass frequency is out of range for this "
                         "channel, band (Hz): {0}, sample rate: {1}".format(
                             args.band_pass, bandts.sample_rate))

    # get darm BLRMS
    print("-- Filtering data")
    # RMS stride (seconds) follows the trend type of the auxiliary data
    stride = 60 if args.trend_type == 'minute' else 1
    if flower:
        darmblrms = (
            bandts.highpass(flower/2., fstop=flower/4.,
                            filtfilt=False, ftype='butter')
            .notch(60, filtfilt=False)
            .bandpass(flower, fupper, fstop=[flower/2., fupper*1.5],
                      filtfilt=False, ftype='butter')
            .crop(start, end).rms(stride))
        darmblrms.name = '%s %s-%s Hz BLRMS' % (primary, flower, fupper)
    else:
        # a zero lower frequency skips the band-pass: notch and RMS only
        darmblrms = bandts.notch(60).crop(start, end).rms(stride)
        darmblrms.name = '%s RMS' % primary

    primaryts = darmblrms

else:
    # load primary channel data
    print("-- Loading primary channel data")
    primaryts = get_data(primary, start, end, frametype=args.primary_frametype,
                         obs=args.ifo[0], verbose='Reading:', nproc=args.nproc)

# optionally clip outliers, then record the statistics that descaler() uses
# to map standardised values back into primary-channel units
sigma_limit = args.remove_outliers
if sigma_limit:
    print("-- Removing outliers above %f sigma" % sigma_limit)
    gwlasso.remove_outliers(primaryts, sigma_limit)

primary_values = primaryts.value
primary_mean = numpy.mean(primary_values)
primary_std = numpy.std(primary_values)


def descaler(l, *coef):
    """Map standardised values back into primary-channel units.

    Each element of ``l`` is multiplied by the standard deviation of the
    primary channel (and by ``coef[0]`` when a coefficient is supplied) and
    then offset by the primary channel mean.  A new list is returned.
    """
    if not coef:
        return [primary_mean + (value * primary_std) for value in l]
    weight = coef[0]
    return [primary_mean + (value * primary_std * weight) for value in l]


# get aux data
print("-- Loading auxiliary channel data")
if args.channel_file is None:
    # no channel file: ask the first NDS2 server for this IFO for every
    # '.mean' trend channel
    # NOTE(review): the query always uses type='m-trend', even when
    # --trend-type is 'second' -- confirm this is intended
    host, port = io_nds2.host_resolution_order(args.ifo)[0]
    channels = ChannelList.query_nds2('*.mean', host=host, port=port,
                                      type='m-trend')
else:
    # one channel name per line; surrounding whitespace (including '\r' from
    # files with Windows line endings) and blank lines are ignored so they
    # cannot be mistaken for channel names
    with open(args.channel_file, 'r') as f:
        channels = [name for name in (line.strip() for line in f) if name]
nchan = len(channels)
print("Identified %d channels" % nchan)

if args.trend_type == 'minute':
    frametype = '%s_M' % args.ifo  # for minute trends
else:
    frametype = '%s_T' % args.ifo  # for second trends

auxdata = get_data(
    channels, start, end, verbose='Reading:', obs=args.ifo[0],
    frametype=frametype, nproc=args.nproc, pad=0)

# -- remove flat data, keeping a table of the channels that were dropped

print('-- Pre-processing auxiliary channel data')

auxdata = gwlasso.remove_flat(auxdata)
dropped = set(channels) - set(auxdata)
flattab = Table(data=(list(dropped),), names=('Channels',))
print('Removed {0} channels with flat data'.format(len(flattab)))
print('{0} channels remaining'.format(len(auxdata)))

# -- remove bad data

print("Removing any channels with bad data...")
nbefore = len(auxdata)
auxdata = gwlasso.remove_bad(auxdata)
nafter = len(auxdata)
print('Removed {0} channels with bad data'.format(nbefore - nafter))
print('{0} channels remaining'.format(nafter))

# standardise every remaining channel and stack them as columns
scaled = [scale(ts.value) for ts in auxdata.values()]
data = numpy.array(scaled).T

# -- perform lasso regression -------------------------------------------------

# create model
print('-- Fitting data to target')
target = scale(primaryts.value)
model = gwlasso.fit(data, target, alpha=args.alpha)
print('Alpha:', model.alpha)

# restructure results for convenience
#     the temporary 'rank' column (|coefficient|) is only used to order the
#     table from largest to smallest magnitude and to split it into
#     non-zero and zero coefficients
allresults = Table(
    data=(list(auxdata.keys()), model.coef_, numpy.abs(model.coef_)),
    names=('Channel', 'Lasso coefficient', 'rank'))
allresults.sort('rank')
allresults.reverse()
useful = allresults['rank'] > 0
allresults.remove_column('rank')
results = allresults[useful]  # non-zero coefficient
zeroed = allresults[numpy.invert(useful)]  # zero coefficient

# extract data for useful channels
nonzerodata = {name: auxdata[name] for name in results['Channel']}
nonzerocoef = {name: coeff for name, coeff in results.as_array()}

# print results
#     ``results`` holds every non-zero coefficient, which is not the same
#     as the number of channels at or above --threshold, so report both
nabove = int(numpy.count_nonzero(
    numpy.abs(results['Lasso coefficient']) >= args.threshold))
print('Found {} channels with non-zero Lasso coefficient '
      '({} with |Lasso coefficient| >= {})'.format(
          len(results), nabove, args.threshold))
print(results)

# convert to pandas
set_option('max_colwidth', 200)
df = results.to_pandas()
df.index += 1  # number the channels from 1 in the HTML table

# write results to files
gpsstub = '%d-%d' % (start, end-start)
resultsfile = '%s-LASSO_RESULTS-%s.txt' % (args.ifo, gpsstub)
results.write(resultsfile, format='ascii', overwrite=True)
zerofile = '%s-ZERO_COEFFICIENT_CHANNELS-%s.txt' % (args.ifo, gpsstub)
zeroed.write(zerofile, format='ascii', overwrite=True)
flatfile = '%s-FLAT_CHANNELS-%s.txt' % (args.ifo, gpsstub)
flattab.write(flatfile, format='ascii', overwrite=True)

# -- generate lasso plots

modelFit = model.predict(data)

# shared helpers for the plots below
re_delim = re.compile(r'[:_-]')
# NOTE(review): ``form`` is not referenced again in the visible script
form = '%%.%dd' % len(str(nchan))
p1 = (0.1, 0.15, 0.9, 0.9)  # global plot defaults for plot1, lasso model

times = primaryts.times.value
xlim = primaryts.span
cmap = mcm.get_cmap('tab20')
# one colour per non-zero channel, plus one spare
colors = list(map(cmap, numpy.linspace(0, 1, len(nonzerodata) + 1)))

# primary channel overlaid with the lasso model prediction
plot = Plot(figsize=(12, 6))
plot.subplots_adjust(*p1)
ax = plot.gca(xscale='auto-gps', epoch=start, xlim=xlim)
primary_label = primary.replace('_', r'\_')
ax.plot(times, descaler(target), label=primary_label,
        color='black', linewidth=args.line_size_primary)
ax.plot(times, descaler(modelFit), label='Lasso model',
        linewidth=args.line_size_aux)
model_ylabel, model_title = (
    ('Sensitive range [Mpc]', 'Lasso Model of Range') if range_is_primary
    else ('Primary Channel Units', 'Lasso Model of Primary Channel'))
ax.set_ylabel(model_ylabel)
ax.set_title(model_title)
ax.legend(loc='best')
model_png = '%s-LASSO_MODEL-%s.png' % (args.ifo, gpsstub)
plot1 = gwplot.save_figure(plot, model_png)

# summed contributions
#     each trace adds one more channel (in order of decreasing
#     |coefficient|) to the running sum of weighted, standardised channels
plot = Plot(figsize=(8, 5))
plot.subplots_adjust(*p1)
ax = plot.gca(xscale='auto-gps', epoch=start, xlim=xlim)
ax.plot(times, descaler(target), label=primary.replace('_', r'\_'),  # primary
        color='black', linewidth=args.line_size_primary)
summed = 0
for rank, name in enumerate(results['Channel'], 1):
    summed = summed + scale(nonzerodata[name].value) * nonzerocoef[name]
    label = 'Channel 1' if rank == 1 else 'Channels 1-{0}'.format(rank)
    ax.plot(times, descaler(summed), label=label, color=colors[rank - 1],
            linewidth=args.line_size_aux)
ax.set_ylabel('Sensitive range [Mpc]' if range_is_primary
              else 'Primary Channel Units')
ax.set_title('Summations of Channel Contributions to Model')

# save legend as a separate figure
#     the legend has to be written before the figure, because save_figure()
#     calls fig.close() which clears the axes
summation_legend_png = (
    '%s-LASSO_CHANNEL_SUMMATION_LEGEND-%s.png' % (args.ifo, gpsstub))
plot2_legend = gwplot.save_legend(ax, summation_legend_png)

# save figure
summation_png = '%s-LASSO_CHANNEL_SUMMATION-%s.png' % (args.ifo, gpsstub)
plot2 = gwplot.save_figure(plot, summation_png)


# individual contributions
#     one trace per non-zero channel: the standardised channel weighted by
#     its lasso coefficient, mapped back into primary-channel units
plot = Plot(figsize=(8, 5))
plot.subplots_adjust(*p1)
ax = plot.gca(xscale='auto-gps', epoch=start, xlim=xlim)
ax.plot(times, descaler(target), label=primary.replace('_', r'\_'),  # primary
        color='black', linewidth=args.line_size_primary)
for i, name in enumerate(results['Channel']):
    this = descaler(scale(nonzerodata[name].value) * nonzerocoef[name])
    # each trace is labelled with its own channel name
    ax.plot(times, this, label=name.replace('_', r'\_'), color=colors[i],
            linewidth=args.line_size_aux)
if range_is_primary:
    ax.set_ylabel('Sensitive range [Mpc]')
else:
    ax.set_ylabel('Primary Channel Units')
ax.set_title('Individual Channel Contributions to Model')

# the legend is saved before the figure, as for the summation plot
plot3_legend = gwplot.save_legend(
    ax, '%s-LASSO_CHANNEL_CONTRIBUTIONS_LEGEND-%s.png' % (args.ifo, gpsstub))
plot3 = gwplot.save_figure(
    plot, '%s-LASSO_CHANNEL_CONTRIBUTIONS-%s.png' % (args.ifo, gpsstub))

# -- process aux channels, making plots

print("-- Processing channels")
# shared integer used by the worker processes to report progress
counter = multiprocessing.Value('i', 0)
p4 = (.1, .1, .9, .95)  # global plot defaults for plot4, timeseries subplots


def _report_progress(chan):
    """Increment the shared counter and print a one-line progress update.

    The counter is read while its lock is still held, so the tally printed
    always corresponds to the increment made by this call.
    """
    with counter.get_lock():
        counter.value += 1
        pc = 100 * counter.value / len(nonzerodata)
        print("Completed [%d/%d] %3d%% %-50s"
              % (counter.value, len(nonzerodata), pc,
                 '(%s)' % str(chan)),
              end='\r')
        sys.stdout.flush()


def process_channel(input_,):
    """Generate the per-channel plots for one auxiliary channel.

    Parameters
    ----------
    input_ : `tuple`
        an ``(index, (channel, timeseries))`` pair, as produced by
        enumerating the items of ``nonzerodata``

    Returns
    -------
    result : `tuple`
        ``(channel, lasso coefficient, trends plot, comparison plot,
        scatter plot, timeseries)``; the three plot entries are `None`
        when the coefficient is zero or below ``--threshold``
    """
    gwplot.configure_mpl()
    chan, ts = input_[1]
    lassocoef = nonzerocoef[chan]
    plot4 = None
    plot5 = None
    plot6 = None

    # nothing is plotted for channels with a zero or sub-threshold coefficient
    if lassocoef == 0 or abs(lassocoef) < args.threshold:
        _report_progress(chan)
        return chan, lassocoef, plot4, plot5, plot6, ts

    # the Pearson coefficient is only computed for minute trends
    if args.trend_type == 'minute':
        pcorr = numpy.corrcoef(ts.value, primaryts.value)[0, 1]
    else:
        pcorr = 0.0

    # create time series subplots
    fig = Plot(figsize=(12, 12))
    fig.subplots_adjust(*p4)
    ax1 = fig.add_subplot(2, 1, 1, xscale='auto-gps', epoch=start)
    ax1.plot(primaryts, label=primary.replace('_', r'\_'),
             color='black', linewidth=args.line_size_primary)
    ax2 = fig.add_subplot(2, 1, 2, sharex=ax1, xlim=xlim)
    ax2.plot(ts, label=chan.replace('_', r'\_'),
             linewidth=args.line_size_aux)
    if range_is_primary:
        ax1.set_ylabel('Sensitive range [Mpc]')
    else:
        ax1.set_ylabel('Primary channel units')
    ax2.set_ylabel('Channel units')
    for ax in fig.axes:
        ax.legend(loc='best')
    # file-name stub: delimiters become '_', except the first which is '-'
    channelstub = re_delim.sub('_', str(chan)).replace('_', '-', 1)
    plot4 = gwplot.save_figure(
        fig, '%s_TRENDS-%s.png' % (channelstub, gpsstub))

    # create scaled, sign-corrected, and overlayed timeseries
    tsscaled = scale(ts.value)
    if lassocoef < 0:
        tsscaled = numpy.negative(tsscaled)
    fig = Plot(figsize=(12, 6))
    fig.subplots_adjust(*p1)
    ax = fig.gca(xscale='auto-gps', epoch=start, xlim=xlim)
    ax.plot(times, descaler(target), label=primary.replace('_', r'\_'),
            color='black', linewidth=args.line_size_primary)
    ax.plot(times, descaler(tsscaled), label=chan.replace('_', r'\_'),
            linewidth=args.line_size_aux)
    if range_is_primary:
        ax.set_ylabel('Sensitive range [Mpc]')
    else:
        ax.set_ylabel('Primary Channel Units')
    ax.legend(loc='best')
    plot5 = gwplot.save_figure(
        fig, '%s_COMPARISON-%s.png' % (channelstub, gpsstub))

    # scatter plot, with a straight-line fit of primary against auxiliary
    tsCopy = ts.value.reshape(-1, 1)
    primarytsCopy = primaryts.value.reshape(-1, 1)
    primaryReg = linear_model.LinearRegression()
    primaryReg.fit(tsCopy, primarytsCopy)
    primaryFit = primaryReg.predict(tsCopy)
    fig = Plot(figsize=(12, 6))
    fig.subplots_adjust(*p1)
    ax = fig.gca()
    ax.set_xlabel(chan.replace('_', r'\_') + ' [Channel units]')
    if range_is_primary:
        ax.set_ylabel('Sensitive range [Mpc]')
    else:
        ax.set_ylabel('Primary channel units')
    ax.text(.9, .1, 'r = ' + str('{0:.2}'.format(pcorr)),
            verticalalignment='bottom', horizontalalignment='right',
            transform=ax.transAxes, color='black', size=20,
            bbox=dict(boxstyle='square', facecolor='white', alpha=.75,
                      edgecolor='black'))
    ax.scatter(ts.value, primaryts.value, color='red')
    ax.plot(ts.value, primaryFit, color='black')
    ax.autoscale_view(tight=False, scalex=True, scaley=True)
    plot6 = gwplot.save_figure(
        fig, '%s_SCATTER-%s.png' % (channelstub, gpsstub))

    # increment counter and print status
    _report_progress(chan)
    return chan, lassocoef, plot4, plot5, plot6, ts


# process channels
pool = multiprocessing.Pool(nprocplot)
try:
    results = pool.map(process_channel, enumerate(list(nonzerodata.items())))
finally:
    # release the worker processes whether or not the mapping succeeded
    pool.close()
    pool.join()
# order by decreasing |lasso coefficient|
results = sorted(results, key=lambda x: abs(x[1]), reverse=True)

#  generate clustered time series plots
# fresh shared progress counter for the clustering pass
counter = multiprocessing.Value('i', 0)
p7 = (.135, .15, .95, .9)  # global plot defaults for plot7, clusters
# maximum number of correlated channels drawn on a single cluster plot
max_correlated_channels = 20


def generate_cluster(input_,):
    """Find and plot the auxiliary channels correlated with one channel.

    Parameters
    ----------
    input_ : `tuple`
        an ``(index, result)`` pair, where ``result`` is one of the tuples
        returned by `process_channel`

    Returns
    -------
    plot7, plot7_legend, plot7_list
        the values returned when saving the cluster plot and its legend,
        and the name of the text file listing every correlated channel;
        all three are `None` if no other channel has a Pearson coefficient
        of magnitude at least ``--cluster-coefficient`` with this channel
    """
    gwplot.configure_mpl()
    current, item = input_
    currentchan = item[0]
    lassocoef = item[1]
    currentts = item[5]
    cluster_threshold = args.cluster_coefficient
    plot7 = None
    plot7_legend = None
    plot7_list = None

    if current < len(nonzerodata):
        # collect (name, timeseries, coefficient) for every other auxiliary
        # channel that is strongly (anti-)correlated with this one
        cluster = []
        for chan_, ts_ in auxdata.items():
            if chan_ == currentchan:
                continue
            pcorr = numpy.corrcoef(currentts.value, ts_.value)[0, 1]
            if abs(pcorr) >= cluster_threshold:
                cluster.append((chan_, ts_, pcorr))

        if cluster:
            # file-name stub: delimiters become '_', except the first ('-')
            currentstub = re_delim.sub(
                '_', str(currentchan)).replace('_', '-', 1)

            # write cluster table to file, strongest correlation first
            cluster.sort(key=lambda x: abs(x[2]), reverse=True)
            names, _, coefficients = zip(*cluster)
            # column order must match the column names given below
            clustertab = Table(data=[names, coefficients],
                               names=('Channel', 'Pearson Coefficient'))
            plot7_list = '%s_CLUSTER_LIST-%s.txt' % (currentstub, gpsstub)
            clustertab.write(plot7_list, format='ascii', overwrite=True)

            # only the most strongly correlated channels are drawn
            ncluster = min(len(cluster), max_correlated_channels)
            colors2 = [cmap(i) for i in numpy.linspace(0, 1, ncluster+1)]

            # plot
            #     every trace is flipped by the sign of the lasso
            #     coefficient, and cluster members also by the sign of
            #     their correlation, so that all traces line up
            sign = numpy.sign(lassocoef)
            fig = Plot(figsize=(8, 5))
            fig.subplots_adjust(*p7)
            ax = fig.gca(xscale='auto-gps')
            ax.plot(
                times, scale(currentts.value)*sign,
                label=currentchan.replace('_', r'\_'),
                linewidth=args.line_size_aux, color=colors[0])

            for color, (chan_, ts_, pcorr) in zip(colors2[1:],
                                                  cluster[:ncluster]):
                ax.plot(
                    times,
                    scale(ts_.value) * sign * numpy.sign(pcorr),
                    color=color,
                    linewidth=args.line_size_aux,
                    label=('{0}, r = {1:.2}'.format(
                        chan_.replace('_', r'\_'), pcorr)),
                )

            ax.margins(x=0)
            ax.set_ylabel('Scaled amplitude [arbitrary units]')
            ax.set_title('Highly Correlated Channels')

            # the legend is saved before the figure itself
            plot7_legend = gwplot.save_legend(
                ax, '%s_CLUSTER_LEGEND-%s.png' % (currentstub, gpsstub))
            plot7 = gwplot.save_figure(
                fig, '%s_CLUSTER-%s.png' % (currentstub, gpsstub))

    # increment counter and print status
    with counter.get_lock():
        counter.value += 1
        pc = 100 * counter.value / len(nonzerodata)
        print("Completed [%d/%d] %3d%% %-50s"
              % (counter.value, len(nonzerodata), pc,
                 '(%s)' % str(currentchan)),
              end='\r')
        sys.stdout.flush()
    return plot7, plot7_legend, plot7_list


if not args.no_cluster:
    print("-- Generating clusters")
    pool = multiprocessing.Pool(nprocplot)
    try:
        # clusters[i] corresponds to results[i]
        clusters = pool.map(generate_cluster, enumerate(results))
    finally:
        # release the worker processes whether or not the mapping succeeded
        pool.close()
        pool.join()

# record the full list of channels that was searched
channelsfile = '%s-CHANNELS-%s.txt' % (args.ifo, gpsstub)
numpy.savetxt(channelsfile, channels, fmt='%s')


# format results table
def format_indices(value):
    """Render a table index inside a centre-aligned ``<div>``."""
    template = "<div style='text-align:center;'>{}</div>"
    return template.format(value)


def format_channels(value):
    """Render a channel name inside a left-aligned, monospaced ``<div>``."""
    style = "text-align:left;font-family:monospace,monospace;"
    template = "<div style='" + style + "'>{}</div>"
    return template.format(value)


def format_coefficients(value):
    """Render a lasso coefficient inside a left-aligned ``<div>``."""
    template = "<div style='text-align:left;'>{}</div>"
    return template.format(value)


def style_table(html_table):
    """Add inline styling to an HTML table given as a string.

    The opening ``<table`` tag and every bare ``<tr>``, ``<th>`` and
    ``<td>`` tag are rewritten to carry inline styles; rows additionally
    get a light grey highlight while the mouse is over them.  Tags that
    already carry attributes (other than ``<table``) are left untouched.

    Parameters
    ----------
    html_table : `str`
        the HTML source of the table

    Returns
    -------
    styled : `str`
        the same HTML with the inline styles inserted
    """
    table_open = ("<table style="
                  "'border-collapse:collapse;"
                  "border-style:hidden;"
                  "width:85%; '")
    # attributes are separated by spaces, and the highlight colour needs its
    # leading '#' to be a valid CSS hex colour
    row_open = ('''<tr style='''
                '''"border-left:0 !important;'''
                '''border-right:0 !important;" '''
                '''onmouseover="this.style.backgroundColor='#f2f2f2';" '''
                '''onmouseout="this.style.backgroundColor='';">''')
    header_open = ("<th style="
                   "'border-left:0 !important;"
                   "border-right:0 !important;"
                   "padding-left:10px;"
                   "padding-right:10px;"
                   "padding-top:5px; "
                   "padding-bottom:5px;'>")
    cell_open = ("<td style = "
                 "'border-left:0 !important;"
                 "border-right:0 !important;'>")
    return (html_table
            .replace('<table', table_open)
            .replace('<tr>', row_open)
            .replace('<th>', header_open)
            .replace('<td>', cell_open))


# write html
title = '%s Lasso Slow Correlations: %d-%d' % (args.ifo, start, end)
page = htmlio.new_bootstrap_page(title=title)
page.div(class_='container')

page.div(class_='page-header')
page.h1(title)
page.div.close()  # page-header

# summary of the settings used for this run, with links to the channel lists
channel_link = "<a href= %s target='_blank'>channel list</a>" % channelsfile
flat_link = "<a href= %s target='_blank'>flat channel list</a>" % (flatfile)
content = [
    ('Primary channel',
     '{} ({})'.format(primary, args.primary_frametype or '-')),
    ('Channels searched', '{} ({})'.format(nchan, channel_link)),
    ('Number of flat channels', '{} ({})'.format(len(flattab), flat_link)),
    ('Lasso coefficient threshold', '{}'.format(args.threshold)),
    ('Sigma for outlier removal', '{}'.format(args.remove_outliers)),
    ('Cluster coefficient threshold', '{}'.format(args.cluster_coefficient)),
]
if args.band_pass:
    filter_settings = 'Flow = {} Hz, Fhigh = {} Hz '.format(flower, fupper)
    content.append(('Filter settings', filter_settings))
page.add(htmlio.write_arguments(content, start=start, end=end))

# -- model information: fit parameters, results table and summary plots
page.h2('Model Information')

page.div(class_='model')
page.div(class_='model-body')

# parameters of the fit, with a link to the zero-coefficient channel list
page.div(class_='model-info')
page.add(htmlio.write_param('Model', 'Lasso'))
page.add(htmlio.write_param(
    'Non-zero coefficients', '%d' % numpy.count_nonzero(model.coef_)))
page.add(htmlio.write_param('Alpha', '%g' % model.alpha))
page.add(htmlio.write_param(
    'Zero coefficients',
    '%d (%s)' % (len(zeroed),
                 "<a href= %s target='_blank'>zeroed channel list</a>"
                 % (zerofile))))

# table of non-zero coefficients; escape=False keeps the <div> wrappers
# added by the formatters as markup rather than escaped text
page.div(class_='results-table', align='center')
page.p('<br /><br />%s' % style_table(df.to_html(
    index=True,
    formatters={
        'Lasso coefficient': format_coefficients,
        'Channel': format_channels,
        '__index__': format_indices},
    escape=False)))
# NOTE(review): page.p(...) above was given its content directly; if the
# page object closes such elements itself, this call emits an unmatched
# </p> -- confirm against the page class returned by new_bootstrap_page
page.p.close()
page.div.close()  # results-table
page.div.close()  # model-info

page.p('<br /><br />')

page.div(class_='container', id_='primary-lasso')
img1 = htmlio.FancyPlot(plot1)
page.add(htmlio.fancybox_img(img1))  # primary lasso plot
page.div.close()  # primary-lasso

# summation plot shown next to its separately-saved legend
page.div(class_='container', id_='channel-summation')
img2 = htmlio.FancyPlot(plot2)
img2_legend = htmlio.FancyPlot(plot2_legend)
page.add(htmlio.scaffold_plots([img2, img2_legend], nperrow=2))
page.div.close()  # channel-summation

# individual contributions plot shown next to its separately-saved legend
page.div(class_='container', id_='channels-and-primary')
img3 = htmlio.FancyPlot(plot3)
img3_legend = htmlio.FancyPlot(plot3_legend)
page.add(htmlio.scaffold_plots([img3, img3_legend], nperrow=2))
page.div.close()  # channels-and-primary

page.div.close()  # model-body
page.div.close()  # model

# results
page.h2('Top Channels')

page.div(class_='panel-group', id_='results')

# for each auxiliary channel create information container and put plots in it
for i, (ch, lassocoef, plot4, plot5, plot6, ts) in enumerate(results):
    # results are sorted by decreasing |coefficient|, so the first zero
    # coefficient marks the end of the channels worth reporting
    if lassocoef == 0:
        break
    below_threshold = abs(lassocoef) < args.threshold

    # set heading and container color/context based on lasso coefficient
    h = '%s [lasso coefficient = %.4f]' % (ch, lassocoef)
    if below_threshold:
        h += ' (Below threshold)'
        context = 'panel-default'
    elif abs(lassocoef) >= .5:
        context = 'panel-danger'
    elif abs(lassocoef) >= .2:
        context = 'panel-warning'
    else:
        context = 'panel-info'
    page.div(class_='panel %s' % context)

    # heading
    page.div(class_='panel-heading')
    page.a(h, class_='panel-title', href='#channel%d' % i,
           **{'data-toggle': 'collapse', 'data-parent': '#results'})
    page.div.close()  # panel-heading
    # body
    page.div(id_='channel%d' % i, class_='panel-collapse collapse')
    page.div(class_='panel-body')
    if below_threshold:
        # no plots were made for this channel
        page.p('Lasso coefficient below the threshold of %g.'
               % (args.threshold))
    else:
        for p in (plot4, plot5, plot6):
            pimg = htmlio.FancyPlot(p)
            page.add(htmlio.fancybox_img(pimg))
        if args.no_cluster is False:
            # clusters[i] was generated from results[i]
            if clusters[i][0] is None:
                page.p("<font size = '3'><br />No channels were highly"
                       " correlated with this channel.</font>")
            else:
                page.div(class_='container', id_='clusters')
                cimg = htmlio.FancyPlot(clusters[i][0])
                cimg_legend = htmlio.FancyPlot(clusters[i][1])
                page.add(htmlio.scaffold_plots([cimg, cimg_legend], nperrow=2))
                page.div.close()  # clusters
                if clusters[i][2] is not None:
                    page.p("<font size = '3'><br />Only the first %d highly"
                           " correlated channels are reported in this"
                           " figure. <a href= %s target='_blank'>Here</a>"
                           " is the full list.</font>"
                           % (max_correlated_channels, clusters[i][2]))
    page.div.close()  # panel-body
    page.div.close()  # panel-collapse
    page.div.close()  # panel
page.div.close()  # panel-group
page.add(htmlio.write_footer())
page.div.close()  # container
# written relative to --output-dir, which is the current directory
with open('index.html', 'w') as f:
    print(str(page), file=f)

print("-- Process Completed")
