#!/usr/bin/python

# Copyright (C) 2017 Michael Coughlin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Gravitational-wave Electromagnetic Optimization

This script generates an optimized list of pointings and content for
reviewing gravitational-wave skymap likelihoods.

Comments should be e-mailed to michael.coughlin@ligo.org.

"""


import os, sys, glob, optparse, shutil, warnings
import numpy as np
np.random.seed(0)

import healpy as hp

import matplotlib
#matplotlib.rc('text', usetex=True)
matplotlib.use('Agg')
matplotlib.rcParams.update({'font.size': 20})
matplotlib.rcParams['contour.negative_linestyle'] = 'solid'
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# =============================================================================
#
#                                    MAIN
#
# =============================================================================

# Suppress every warning category for the rest of the run.
warnings.filterwarnings("ignore")

# Input tables for GW190425.  Each odd-numbered file is followed by the file
# from the matching "_Iterative" output directory.
(
    filename1,
    filename2,
    filename3,
    filename4,
    filename5,
    filename6,
) = (
    "../output/GW190425/0.900_DECam_ZTF_GIT/tiles_coverage_int.dat",
    "../output/GW190425/0.900_DECam_ZTF_GIT_Iterative/tiles_coverage_int.dat",
    "../output/GW190425/0.900_PS1_ATLAS/tiles_coverage_int.dat",
    "../output/GW190425/0.900_PS1_ATLAS_Iterative/tiles_coverage_int.dat",
    "../output/GW190425/0.900_TCA_OAJ_IRIS/tiles_coverage_int.dat",
    "../output/GW190425/0.900_TCA_OAJ_IRIS_Iterative/tiles_coverage_int.dat",
)

# Load the six tables in filename order.
data1, data2, data3, data4, data5, data6 = (
    np.loadtxt(path)
    for path in (filename1, filename2, filename3,
                 filename4, filename5, filename6)
)

# 25 evenly spaced values from 0 to 24 (not referenced by the plotting code
# below; kept as in the original script).
bins = np.linspace(0.0, 24.0, num=25)

# Output file and a single row of three panels.
plotName = '../output/tiles_coverage_int.pdf'
fig, ax = plt.subplots(nrows=1, ncols=3, figsize=(20, 8))

# Per-panel axis ranges: x range, left-hand y range, right-hand (twin) y range.
xlimits = [
    [0, 1.0],
    [0, 1.0],
    [0.0, 1.0],
]
ylimits = [
    [0, 0.6],
    [0, 0.6],
    [0, 0.3],
]
ylimits2 = [
    [0, 6000],
    [0, 6000],
    [0, 1200],
]

# Per-panel list of [start, end] x intervals that the plotting loop shades.
days = [
    [[0.15, 0.25]],
    [[0.30, 0.9]],
    [[0.0, 0.45], [0.80, 1.0]],
]

# Panel titles, in panel order.
titles = ['GROWTH', 'PS1/ATLAS', 'GRANDMA']

def _padded_curves(table):
    """Return ``(tts, cum_probs, cum_areas)`` from *table*, padded at both ends.

    Columns 0, 1 and 2 of *table* are read as the time, cumulative
    probability and cumulative area series.  A leading sample
    ``(0.0, 0.0, 0.0)`` is prepended, and a trailing sample at time ``1.0``
    that repeats the last probability and area is appended, so every curve
    spans the whole 0-1 x range drawn in the panels.
    """
    # np.loadtxt yields a 1-D array for a single-row file; promote it so the
    # column slices below work in that case too.  2-D input is unchanged.
    table = np.atleast_2d(table)
    tts, cum_probs, cum_areas = table[:, 0], table[:, 1], table[:, 2]

    tts = np.hstack((0.0, tts, 1.0))
    cum_probs = np.hstack((0.0, cum_probs, cum_probs[-1]))
    cum_areas = np.hstack((0.0, cum_areas, cum_areas[-1]))
    return tts, cum_probs, cum_areas


# One (original, iterative) pair of tables per panel, in panel order.
panel_tables = [(data1, data2), (data3, data4), (data5, data6)]
# Line colour and legend label for the first and second table of each pair.
curve_styles = [('k', 'Original'), ('b', 'Iterative')]

for ii, tables in enumerate(panel_tables):

    xlim, ylim, ylim2 = xlimits[ii], ylimits[ii], ylimits2[ii]
    ax3 = ax[ii].twinx()   # second y axis sharing this panel's x axis

    # Solid line on the left axis for probability, dashed line on the twin
    # axis for area; both use the same colour for a given table.
    for table, (color, label) in zip(tables, curve_styles):
        tts, cum_probs, cum_areas = _padded_curves(table)
        ax[ii].plot(tts, cum_probs, color=color, linestyle='-', label=label)
        ax3.plot(tts, cum_areas, color=color, linestyle='--')

    # Shade each interval listed for this panel.  The rectangle height of 1
    # exceeds every left-axis upper limit in ylimits, so it fills the panel.
    for jj, (day_min, day_max) in enumerate(days[ii]):
        rect = Rectangle((day_min, 0), day_max-day_min, 1, alpha=0.2, color='y')
        ax[ii].add_patch(rect)

        # Only the first interval of a panel gets a text label.
        if jj > 0:
            continue

        # Hand-tuned horizontal offsets from the interval midpoint, one rule
        # per panel, chosen so the label sits inside/near the shaded band.
        width = day_max-day_min
        middle = (day_max+day_min)/2.0
        if ii == 1:
            text_x = -0.35*width + middle
        elif ii == 0:
            text_x = -0.2 + middle
        else:
            text_x = -0.42*width + middle
        ax[ii].text(text_x, 0.9*ylim[1], 'Unobservable', fontsize=16)

    # A single legend on the middle panel (loc=8 is 'lower center').
    if ii == 1:
        ax[ii].legend(loc=8)

    ax[ii].set_title(titles[ii])
    ax[ii].set_xlim(xlim)
    ax[ii].set_ylim(ylim)
    ax3.set_ylim(ylim2)
    ax[ii].grid()

# Shared axis labels placed on the figure rather than on individual panels.
fig.text(0.5, 0.01, 'Time since event [days]', ha='center', fontsize=30)
fig.text(0.05, 0.5, 'Integrated Probability', va='center', rotation='vertical', fontsize=30)
fig.text(0.95, 0.5, 'Sky area [sq. deg.]', va='center', rotation='vertical', fontsize=30)

plt.subplots_adjust(wspace=0.40)

# Write the file before calling show(): on an interactive backend the figure
# may be discarded once its window is closed, which would leave an empty
# output file.  Saving through `fig` also avoids relying on the "current
# figure" state.
fig.savefig(plotName, dpi=200, bbox_inches='tight')
plt.show()
plt.close('all')

