#!/usr/bin/python

# Copyright (C) 2017 Michael Coughlin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

""".
Gravitational-wave Electromagnetic Optimization

This script generates an optimized list of pointings and content for
reviewing gravitational-wave skymap likelihoods.

Comments should be e-mailed to michael.coughlin@ligo.org.

"""


import os, sys, glob, optparse, shutil, warnings
import copy
import numpy as np
np.random.seed(0)

import healpy as hp
from astropy.table import unique, vstack, Table
from astropy import units as u
from astropy.time import Time
from astropy import cosmology
from astropy.coordinates import Distance
from astropy.coordinates import SkyCoord

import matplotlib
#matplotlib.rc('text', usetex=True)
matplotlib.use('Agg')
matplotlib.rcParams.update({'font.size': 16})
matplotlib.rcParams['contour.negative_linestyle'] = 'solid'
import matplotlib.pyplot as plt
from matplotlib import cm

import ligo.skymap.distance as ligodist

import gwemopt.utils, gwemopt.plotting
import gwemopt.moc, gwemopt.tiles 

__author__ = "Michael Coughlin <michael.coughlin@ligo.org>"
__version__ = 1.0
__date__    = "6/17/2017"

# =============================================================================
#
#                               DEFINITIONS
#
# =============================================================================

def parse_commandline():
    """Parse the options given on the command-line.

    Options are read from ``sys.argv`` by ``optparse``; positional arguments
    are parsed but discarded.  With ``--verbose`` every option and its value
    is echoed to standard error.

    Returns
    -------
    optparse.Values
        The parsed options.
    """
    # optparse expands "%prog" in the version text, which requires a string,
    # so the numeric module-level __version__ is converted here.
    parser = optparse.OptionParser(usage=__doc__, version=str(__version__))

    #parser.add_option("-s", "--skymap", help="GW skymap.", default='../data/S200115j/LALInference.fits.gz')
    #parser.add_option("-s", "--skymap", help="GW skymap.", default='../data/S200115j/bayestar.fits.gz')
    #parser.add_option("--fields", help="Observed fields.", default='../data/S200115j/ZTF_fields.dat,../data/S200115j/TCA_fields.dat,../data/S200115j/TCH_fields.dat,../data/S200115j/TRE_fields.dat')

    parser.add_option("-s", "--skymap", help="GW skymap.", default='../data/S200213t/LALInference.fits.gz')
    parser.add_option("--fields", help="Observed fields.", default='../data/S200213t/ZTF_fields.dat,../data/S200213t/TCA_fields.dat,../data/S200213t/TCH_fields.dat,../data/S200213t/TRE_fields.dat,../data/S200213t/OAJ_fields.dat')

    parser.add_option("--transients", help="Transient list.", default='../data/GW190425/transients.dat')

    parser.add_option("-c", "--configDirectory", help="GW-EM config file directory.", default ="../config/")
    #parser.add_option("-o", "--outputDir", help="output directory",default="../output/S200115j")
    parser.add_option("-o", "--outputDir", help="output directory",default="../output/S200213t")

    #parser.add_option("-t", "--telescope", help="Telescope.", default ="ZTF,TCA,TCH,TRE")
    parser.add_option("-t", "--telescope", help="Telescope.", default ="ZTF,TCA,TCH,TRE,OAJ")
    #parser.add_option("-g", "--gps", help="Event time GPS.", default=1263097398, type=float)
    parser.add_option("-g", "--gps", help="Event time GPS.", default=1265627458.327981, type=float)

    parser.add_option("-m", "--tmax", help="Max time post trigger.", default=3.0, type=float)

    parser.add_option("--nside",default=256,type=int)
    parser.add_option("-f","--filter",default="1,2")
    parser.add_option("--rotation",default=120.0,type=float)

    parser.add_option("--catalogFile", default="../catalogs/CLU.hdf5")

    parser.add_option("--doTransients",action="store_true",default=False)

    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="Run verbosely. (Default: False)")

    opts, args = parser.parse_args()

    # show parameters (on stderr, so stdout stays clean for the report)
    if opts.verbose:
        print("", file=sys.stderr)
        print("running gwemopt_run...", file=sys.stderr)
        print("version: %s" % __version__, file=sys.stderr)
        print("", file=sys.stderr)
        print("***************** PARAMETERS ********************", file=sys.stderr)
        for key, value in vars(opts).items():
            print(key + ":", file=sys.stderr)
            print(value, file=sys.stderr)
        print("", file=sys.stderr)

    return opts

# =============================================================================
#
#                                    MAIN
#
# =============================================================================

warnings.filterwarnings("ignore")

# Parse command line
opts = parse_commandline()
# Integer filter IDs; only observations taken in these filters are drawn.
filters = [int(x) for x in opts.filter.split(",")]

# The galaxy catalog is consulted only by the --doTransients report, so it
# is read only when that report is requested.  Runs without the report no
# longer require the catalog file to be present.
if opts.doTransients:
    t = Table.read(opts.catalogFile)
    name = t["name"]
    ra, dec = t["ra"], t["dec"]
    sfr_fuv, mstar = t["sfr_fuv"], t["mstar"]
    distmpc, magb = t["distmpc"], t["magb"]
    a, b2a, pa = t["a"], t["b2a"], t["pa"]
    btc = t["btc"]

    # Sky positions (with distances in Mpc) used for galaxy cross-matching.
    catalog = SkyCoord(ra=ra,
                       dec=dec,
                       distance=distmpc*u.Mpc,
                       frame='icrs')

# Load the configuration of every requested telescope that has a
# "<telescope>.config" file in the configuration directory.
params = {}
params["config"] = {}
requested_telescopes = opts.telescope.split(",")
configFiles = glob.glob(os.path.join(opts.configDirectory, "*.config"))
for configFile in configFiles:
    # The telescope name is the file name without directory or extension.
    telescope = os.path.splitext(os.path.basename(configFile))[0]
    if telescope not in requested_telescopes:
        continue
    config = gwemopt.utils.readParamsFromFile(configFile)
    config["telescope"] = telescope
    # First three columns of the tessellation file; '%' starts a comment.
    config["tesselation"] = np.loadtxt(config["tesselationFile"],
                                       usecols=(0,1,2), comments='%')
    config["tot_obs_time"] = 1.0
    params["config"][telescope] = config

# The transient list is used only by the --doTransients report, so it is
# read only when that report is requested.
if opts.doTransients:
    transients = Table.read(opts.transients, format="ascii",
                            names=("name","ra","dec"))

# Make sure the output directory exists.
params["outputDir"] = opts.outputDir
os.makedirs(params["outputDir"], exist_ok=True)

# Remaining gwemopt settings, collected in one place.
params.update({
    "tilesType": "moc",
    "doMinimalTiling": False,
    "doParallel": False,
    "telescopes": opts.telescope.split(","),
    "nside": opts.nside,
    "doChipGaps": False,
    "doSingleExposure": False,
    "powerlaw_n": 0.0,
    "powerlaw_cl": 0.9,
    "powerlaw_dist_exp": 0.0,
    # NOTE(review): fixed value rather than opts.gps -- confirm intended.
    "gpstime": 1000000000,
    "Tobs": np.array([0.0,1.0]),
    "exposuretimes": [30.0],
    "rotation": [opts.rotation,0,0],
})

# Let gwemopt validate/complete the parameters, then build the MOC
# structures from them.
params = gwemopt.utils.params_checker(params)
#params = gwemopt.segments.get_telescope_segments(params)
moc_structs = gwemopt.moc.create_moc(params)

# Fields 0-3 of the skymap file are taken as probability, distance mean,
# distance standard deviation and distance normalization respectively.
prob_data, distmu_data, distsigma_data, norm_data = hp.read_map(
    opts.skymap, field=(0,1,2,3), verbose=False)

map_struct = {
    "distmu": distmu_data,
    "distsigma": distsigma_data,
    "prob": prob_data,
    "distnorm": norm_data,
    "ipix_keep": [],
}

# Regrade the maps to the requested nside; ud_grade returns all four layers,
# so the distance normalization read from the file is replaced here.
(map_struct["prob"],
 map_struct["distmu"],
 map_struct["distsigma"],
 map_struct["distnorm"]) = ligodist.ud_grade(map_struct["prob"],
                                             map_struct["distmu"],
                                             map_struct["distsigma"],
                                             opts.nside)

# Area of one pixel at the working resolution, in square degrees.
pixarea_deg2 = hp.nside2pixarea(opts.nside, degrees=True)
map_struct["pixarea_deg2"] = pixarea_deg2

# Optional report: for every transient print the skymap distance at its
# position and up to five catalog galaxies within 100 arcsec of it.
if opts.doTransients:
    for transient in transients:
        print(transient["name"])

        # Skymap pixel containing the transient.
        pix = hp.ang2pix(opts.nside, transient["ra"],
                         transient["dec"], lonlat=True)
        mu = map_struct["distmu"][pix]
        d = Distance(mu, u.Mpc)

        # Keep mu - sigma positive.  The clamp is written back into
        # map_struct, so it is also seen by the code that follows.
        if map_struct["distsigma"][pix] > mu:
            map_struct["distsigma"][pix] = mu-1
            print('error greater than mean, setting close to equal')
        sigma = map_struct["distsigma"][pix]

        d1 = Distance(mu-sigma, u.Mpc)
        d2 = Distance(mu+sigma, u.Mpc)

        # Redshift of the mean distance and the mean absolute redshift
        # offset of the two one-sigma bounds.
        z_mean = d.z
        z_lower = d1.z
        z_upper = d2.z
        z_err = np.mean(np.abs([z_mean-z_lower, z_mean-z_upper]))
        print('Distance: %.0f +- %.0f, %.3f +- %.3f' % (mu, sigma, z_mean, z_err))

        transient_coord = SkyCoord(ra=transient["ra"]*u.degree,
                                   dec=transient["dec"]*u.degree,
                                   frame='icrs')
        # Angular separation to every catalog entry, nearest first.
        sep = catalog.separation(transient_coord).arcsec
        nearest = np.argsort(sep)
        print('Galaxies')
        for rank in range(5):
            gal = nearest[rank]
            if sep[gal] > 100: continue
            dg = Distance(distmpc[gal], u.Mpc)
            galaxy_info = (name[gal], ra[gal], dec[gal],
                           sep[gal], distmpc[gal], dg.z)
            print('%s %.5f %.5f %.5f %.5f %.5f' % galaxy_info)

        print(' ')

# Per-telescope tiles for the skymap (requested with doSegments=False).
tiles_structs = gwemopt.tiles.moc(params, map_struct, moc_structs,
                                  doSegments=False)

# Field files and telescope names, split from their comma-separated options.
fieldsfiles = opts.fields.split(",")
telescopes = opts.telescope.split(",")

# Colors handed out to the telescopes in plotting order.
colors = ['cornflowerblue', 'cyan', 'darkgreen', 'darkmagenta', 'pink']

cmap = plt.get_cmap('cylon')

plotName = os.path.join(params["outputDir"],'tiles.pdf')

# Sky map on the upper three fifths of the figure, time series below it.
fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(5, 1)
ax1 = fig.add_subplot(gs[0:3, 0], projection='astro hours mollweide')
ax2 = fig.add_subplot(gs[3:5, 0])
#ax3 = ax2.twinx()   # mirror them

plt.axes(ax1)

hp.mollview(map_struct["prob"], title='',
            unit='Gravitational-wave probability', cbar=False,
            rot=[opts.rotation,0,0], cmap=cmap, hold=True)

hp.graticule(verbose=False, dmer=3*360.0/24.0, dpar=30.0)
plt.grid(True)

# Hour labels every three hours along latitude 0.
for hour in np.arange(0.0,24,3.0):
    hp.projtext(hour*360.0/24.0, 0.0, "%.0f hr"%hour, lonlat=True, color='k')

# Latitude labels every 30 degrees at longitude 120, skipping latitude 0.
for lat in np.arange(-60.0,90,30.0):
    if lat == 0.0: continue
    hp.projtext(120.0, lat, "%.0f"%lat, lonlat=True, color='k')
# Draw every telescope's tiles on the sky map: tiles observed within
# (0, tmax) days of the trigger in one of the selected filters are shown
# half-transparent in the telescope color, all others fully transparent.
ax = plt.gca()
labels = []

# A GPS-format Time defaults to the TAI scale while the ISOT field times are
# read as UTC; convert the trigger to UTC so the MJD difference is not
# offset by the TAI-UTC leap-second count.
start_time = Time(opts.gps, format='gps').utc.mjd

for kk, telescope in enumerate(tiles_structs):
    labels.append(telescope)
    tiles_struct = tiles_structs[telescope]
    # Wrap around the palette instead of failing when there are more
    # telescopes than colors.
    color = colors[kk % len(colors)]

    # --fields and --telescope are parallel lists.
    fields = fieldsfiles[telescopes.index(telescope)]

    table = Table.read(fields, format="ascii",
                       names=("fieldID","fId","time","limmag","exposure_time"))

    # Days since the trigger; keep observations inside (0, tmax).
    table["mjd"] = Time(table["time"], format='isot').mjd - start_time
    table.sort('mjd')

    table = table[table["mjd"] > 0]
    table = table[table["mjd"] < opts.tmax]

    if len(table) == 0: continue

    # Keep one observation per field (the table is time-sorted beforehand).
    table = unique(table, keys=['fieldID'], keep='first')
    table.sort('mjd')

    for index in tiles_struct.keys():
        patch = tiles_struct[index]["patch"]
        if not patch: continue

        # Number of retained observations of this field in a selected filter.
        rows = table[np.where(table["fieldID"]==index)[0]]
        n_obs = sum(1 for row in rows if row["fId"] in filters)

        # Work on a detached copy so the stored patch is left untouched.
        patch_cpy = copy.copy(patch)
        patch_cpy.axes = None
        patch_cpy.figure = None
        patch_cpy.set_transform(ax.transData)
        current_alpha = patch_cpy.get_alpha()

        # get_alpha() returns None when no alpha was set explicitly; such a
        # patch is visible, so it is treated like a positive alpha.
        if current_alpha is None or current_alpha > 0.0:
            patch_cpy.set_alpha(0.5 if n_obs >= 1 else 0)
            patch_cpy.set_color(color)
        #hp.projaxes.HpxMollweideAxes.add_patch(ax,patch_cpy)
        ax.add_patch(patch_cpy)

# Legend built from empty plots: one square marker per telescope.
patches = [ plt.plot([],[], marker="s", ms=10, ls="", mec=None,
                     color=colors[i % len(colors)],
                     label="{:s}".format(label) )[0]
            for i, label in enumerate(labels) ]
plt.legend(handles=patches, loc=8, ncol=5, facecolor="plum", numpoints=1 )

#for transient in transients:
#    hp.visufunc.projplot(transient["ra"], transient["dec"], lonlat=True,
#                         marker="o",markersize=10)
#    hp.visufunc.projtext(transient["ra"]+2.0, transient["dec"], 
#                         transient["name"][-3:],
#                         lonlat=True, fontsize=4, color='w')

plt.axes(ax2)


def _coverage_curve(observations):
    """Accumulate sky coverage over time-ordered ``(mjd, ipix)`` pairs.

    Returns three parallel lists: the observation times, the skymap
    probability summed over the union of all pixels seen so far, and the
    area of that union in square degrees.
    """
    seen = np.empty(0, dtype=int)
    times, probs, areas = [], [], []
    for mjd, ipix in observations:
        # Pixels shared between overlapping tiles are counted once.
        seen = np.unique(np.append(seen, ipix)).astype(int)
        probs.append(np.sum(map_struct["prob"][seen]))
        areas.append(len(seen) * map_struct["pixarea_deg2"])
        times.append(mjd)
    return times, probs, areas


# A GPS-format Time defaults to the TAI scale while the ISOT field times are
# read as UTC; convert the trigger to UTC so the MJD difference is not
# offset by the TAI-UTC leap-second count.
start_time = Time(opts.gps, format='gps').utc.mjd

# (mjd, ipix) of every retained observation, over all telescopes.
observations = []
for kk, telescope in enumerate(tiles_structs):
    tiles_struct = tiles_structs[telescope]
    # Wrap around the palette instead of failing when there are more
    # telescopes than colors.
    color = colors[kk % len(colors)]

    # --fields and --telescope are parallel lists.
    fields = fieldsfiles[telescopes.index(telescope)]

    table = Table.read(fields, format="ascii",
                       names=("fieldID","fId","time","limmag","exposure_time"))

    # Days since the trigger; keep observations inside (0, tmax).
    table["mjd"] = Time(table["time"], format='isot').mjd - start_time
    table.sort('mjd')

    table = table[table["mjd"] > 0]
    table = table[table["mjd"] < opts.tmax]

    if len(table) == 0: continue

    # Keep one observation per field (the table is time-sorted beforehand).
    table = unique(table, keys=['fieldID'], keep='first')
    table.sort('mjd')

    telescope_obs = []
    for row in table:
        # Find the tile of this field.  A field without a tile is skipped;
        # it must not fall through to some unrelated tile's pixels.
        index = next((key for key in tiles_struct
                      if row["fieldID"] == key), None)
        if index is None:
            print("Skipping field %s: no matching %s tile"
                  % (row["fieldID"], telescope), file=sys.stderr)
            continue
        telescope_obs.append((row["mjd"], tiles_struct[index]["ipix"]))

    tts, cum_probs, cum_areas = _coverage_curve(telescope_obs)
    ax2.plot(tts, cum_probs, color=color, linestyle='-', label=telescope)
    #ax3.plot(tts, cum_areas, color=color, linestyle='--')

    observations.extend(telescope_obs)

# Combined curve: all telescopes merged in time order.  This also works when
# no telescope contributed any observation (an empty line is drawn).
observations.sort(key=lambda obs: obs[0])
tts, cum_probs, cum_areas = _coverage_curve(observations)

color, telescope = 'k', 'All'
ax2.plot(tts, cum_probs, color=color, linestyle='-', label=telescope)
#ax3.plot(tts, cum_areas, color=color, linestyle='--')

ax2.set_xlabel('Time since event [days]')
ax2.set_ylabel('2D Integrated Probability')
#ax3.set_ylabel('Sky area [sq. deg.]')

#for transient in transients:
#    hp.visufunc.projplot(transient["ra"], transient["dec"], lonlat=True,
#                         marker="o",markersize=10)
#    hp.visufunc.projtext(transient["ra"]+2.0, transient["dec"], 
#                         transient["name"][-3:],
#                         lonlat=True, fontsize=4, color='w')

#gwemopt.plotting.add_edges()
# Write the file before showing: on an interactive backend the figure can be
# gone once show() returns, which would leave an empty output file.
plt.savefig(plotName,dpi=200,bbox_inches='tight')
plt.show()
plt.close('all')


