#!/usr/bin/python

# Copyright (C) 2017 Michael Coughlin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

""".
Gravitational-wave Electromagnetic Optimization

This script generates an optimized list of pointings and content for
reviewing gravitational-wave skymap likelihoods.

Comments should be e-mailed to michael.coughlin@ligo.org.

"""


import os, sys, glob, optparse, shutil, warnings
import numpy as np
np.random.seed(0)

import healpy as hp

import matplotlib
#matplotlib.rc('text', usetex=True)
matplotlib.use('Agg')
matplotlib.rcParams.update({'font.size': 16})
matplotlib.rcParams['contour.negative_linestyle'] = 'solid'
import matplotlib.pyplot as plt

import gwemopt.utils, gwemopt.plotting

import ligo.skymap
from ligo.skymap import io
from ligo.skymap import postprocess

__author__ = "Michael Coughlin <michael.coughlin@ligo.org>"
__version__ = 1.0
__date__    = "6/17/2017"

# =============================================================================
#
#                               DEFINITIONS
#
# =============================================================================

def parse_commandline():
    """Parse the options given on the command line.

    Returns
    -------
    optparse.Values
        The parsed options.  Positional arguments are ignored.
    """
    # optparse calls string methods on ``version`` when --version is used,
    # so the numeric module version is converted to text here.
    parser = optparse.OptionParser(usage=__doc__, version=str(__version__))

    parser.add_option("-s", "--skymap", help="GW skymap.", default='../output/GW190901/skymap.fits')
    parser.add_option("-p", "--skymapplot", help="GW skymap plot.", default='../output/GW190901/skymap.pdf')

    parser.add_option("--nside",default=512,type=int)

    parser.add_option("--doCircle",  action="store_true", default=False)
    parser.add_option("--doCircleZero",  action="store_true", default=False)
    parser.add_option("--doAnnulus",  action="store_true", default=False)
    parser.add_option("--doExcise",  action="store_true", default=False)
    parser.add_option("--excise_file",default="/Users/mcoughlin/Code/ZTF/growth-too-marshal/pixels")
    # --ra/--dec stay untyped on purpose: gen_circle and gen_circle_zero
    # accept comma-separated lists, so CLI values arrive as strings.
    parser.add_option("--ra",default=120.28)
    parser.add_option("--dec",default=6.35)
    parser.add_option("--radius",default=0.77,type=float)
    # NOTE(review): with these defaults radius - radius_inner is negative,
    # which is what gen_annulus uses for its inner disc -- confirm intended.
    parser.add_option("--radius_inner",default=8.242,type=float)
    parser.add_option("--radius_outer",default=2.398,type=float)

    parser.add_option("--doGalactic",  action="store_true", default=False)

    parser.add_option("--doSquare",  action="store_true", default=False)
    #parser.add_option("--ras",default="290.010,290.186,265.423,264.736")
    #parser.add_option("--decs",default="48.528,48.912,43.214,42.455")

    parser.add_option("--doTESS",  action="store_true", default=False)
    parser.add_option("--ras",default="324.5670,338.5766,19.4927,90.0042")
    parser.add_option("--decs",default="-33.1730,-55.0789,-71.9781,-66.5647")
    parser.add_option("--lon",default="10.94,10.94,10.94,10.94")
    parser.add_option("--lat",default="-18.0,-42.0,-66.0,-89.0")

    parser.add_option("--doIntersect",  action="store_true", default=False)
    parser.add_option("-o", "--original_skymap", help="GW skymap.", default='../data/GW190901/bayestar.fits.gz')

    parser.add_option("--doCandidate",  action="store_true", default=False)
    parser.add_option("--ra_candidate",default=255.326375,type=float)
    parser.add_option("--dec_candidate",default=-7.002889,type=float)

    parser.add_option("--doLongitudeCut",  action="store_true", default=False)
    parser.add_option("--longitude_low",default=9.0,type=float)
    parser.add_option("--longitude_high",default=20.0,type=float)

    parser.add_option("--doGetFields",  action="store_true", default=False)
    parser.add_option("--fields_type", default='circle')

    # Typed so that a value given on the command line is numeric, like the
    # default, when it is used as a contour level.
    parser.add_option("--level",default=90,type=float)

    parser.add_option("--doRotate",  action="store_true", default=False)
    parser.add_option("--theta",default=150.0,type=float)
    parser.add_option("--phi",default=10.0,type=float)

    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="Run verbosely. (Default: False)")

    opts, args = parser.parse_args()

    # show parameters
    if opts.verbose:
        # sys.stderr.write is used instead of the Python 2 only
        # ``print >> sys.stderr`` form, which fails at runtime on Python 3.
        report = sys.stderr.write
        report("\n")
        report("running gwemopt_run...\n")
        report("version: %s\n" % __version__)
        report("\n")
        report("***************** PARAMETERS ********************\n")
        for name, value in vars(opts).items():
            report("%s:\n" % name)
            report("%s\n" % (value,))
        report("\n")

    return opts

def tesselation_spiral(FOV, scale=0.80):
    """Place tile centres along a golden-angle spiral covering the sphere.

    Each tile is treated as a disc of radius ``FOV`` whose area is reduced by
    ``scale``.  The point count is the full-sky area in square degrees
    divided by that reduced area, rounded up (so ``FOV`` is presumably in
    degrees).

    Returns the two arrays given by healpy's ``vec2ang`` with ``lonlat=True``
    (longitude, latitude), used by the callers as RA and Dec.
    """
    tile_area = np.pi*FOV*FOV*scale
    sky_area = 4*np.pi*(180/np.pi)**2
    npoints = int(np.ceil(sky_area/tile_area))
    print("Using %d points to tile the sphere..."%npoints)

    # Successive points advance by the golden angle in azimuth while their
    # height steps evenly from just below +1 to just above -1.
    azimuth = (np.pi * (3 - np.sqrt(5))) * np.arange(npoints)
    height = np.linspace(1 - 1.0 / npoints, 1.0 / npoints - 1, npoints)
    rho = np.sqrt(1 - height * height)

    xyz = np.column_stack((rho * np.cos(azimuth),
                           rho * np.sin(azimuth),
                           height))

    lon, lat = hp.pixelfunc.vec2ang(xyz, lonlat=True)
    return lon, lat

def do_excise(healpix,excise_file):
    """Zero the pixels listed in ``excise_file`` and renormalise the map.

    The indices are read with ``np.loadtxt`` and cast to int.  The input
    array is modified in place (listed pixels set to zero); the returned
    array is a new one divided by its sum.
    """
    excised = np.loadtxt(excise_file).astype(int)
    healpix[excised] = 0.0
    total = np.sum(healpix)
    return healpix / total

def get_longitudecut(healpix, nside, longitude_low, longitude_high):
    """Zero every pixel outside a right-ascension window and renormalise.

    The two limits are multiplied by 360/24 before being compared with the
    pixel RA in degrees, i.e. they are treated as hours.  The input array is
    modified in place; the returned array is a new one divided by its sum.
    """
    pixel_ids = np.arange(hp.nside2npix(nside))
    _, phi = hp.pix2ang(nside, pixel_ids)
    ra_deg = np.rad2deg(phi)

    above = longitude_high*360.0/24.0 < ra_deg
    below = longitude_low*360.0/24.0 > ra_deg
    rejected = np.flatnonzero(above | below)

    healpix[rejected] = 0.0
    return healpix / np.sum(healpix)

def get_intersect(healpix,original_skymap,nside):
    """Multiply the map by a skymap read from disk and renormalise.

    ``original_skymap`` is read with ``hp.read_map`` and regraded to
    ``nside`` with ``power=-2`` before the pixel-wise product is taken.
    Returns the product divided by its sum.
    """
    reference = hp.read_map(original_skymap, verbose=True)
    # The result is not used; the call is kept so that whatever check healpy
    # performs on the map size here still runs -- TODO confirm it is needed.
    hp.pixelfunc.get_nside(reference)
    reference = hp.ud_grade(reference,nside,power=-2)

    combined = healpix * reference
    return combined / np.sum(combined)

def gen_annulus(ra,dec,radius,error_inner,error_outer,nside):
    """Build a flat, normalised map over a ring centred on (ra, dec).

    The ring is the set of pixels inside the disc of radius
    ``radius + error_outer`` but not inside the disc of radius
    ``radius - error_inner`` (all in degrees).  ``ra`` and ``dec`` are
    converted with ``float`` first, so numeric strings are accepted.
    """
    longitude = np.deg2rad(float(ra))
    colatitude = 0.5 * np.pi - np.deg2rad(float(dec))
    centre = hp.ang2vec(colatitude, longitude)

    inner_disc = hp.query_disc(nside, centre, np.deg2rad(radius-error_inner))
    outer_disc = hp.query_disc(nside, centre, np.deg2rad(radius+error_outer))

    # Paint the outer disc, then clear the inner one: what remains set is
    # exactly the pixels of the outer disc that are not in the inner disc.
    ring = np.zeros(hp.nside2npix(nside))
    ring[outer_disc] = 1.
    ring[inner_disc] = 0.

    return ring / np.sum(ring)

def gen_circle(ra,dec,error,nside):
    """Build a normalised map of smoothed points at one or more positions.

    Parameters
    ----------
    ra, dec : number or str
        Position(s) in degrees.  A string may hold a comma-separated list;
        both must describe the same number of positions.
    error : float
        Width in degrees, passed to ``hp.smoothing`` as ``sigma``.
    nside : int
        HEALPix resolution of the output map.

    Returns
    -------
    numpy.ndarray or None
        The smoothed map divided by its sum, or ``None`` (after printing a
        message) when the numbers of RAs and Decs differ.
    """
    # Parse each coordinate on its own so that lists of different lengths are
    # reported below instead of failing while building a ragged array.
    ra_values = [float(v) for v in ra.split(",")] if isinstance(ra, str) else [float(ra)]
    dec_values = [float(v) for v in dec.split(",")] if isinstance(dec, str) else [float(dec)]
    if len(ra_values) != len(dec_values):
        print("Need equal number of RAs and decs")
        return None

    npix = hp.nside2npix(nside)
    n = np.zeros(npix)

    # Mark the pixel containing each requested position.
    seeds = hp.ang2pix(nside, np.array(ra_values), np.array(dec_values),
                       lonlat=True)
    n[seeds] = 1.

    healpix = hp.smoothing(n, sigma=np.deg2rad(error), verbose=False)
    healpix = healpix / np.sum(healpix)

    return healpix

def angular_distance(ra1, dec1, ra2, dec2):
    """Return the great-circle separation, in degrees, of two sky positions.

    All inputs are in degrees and may be scalars or NumPy arrays.  The
    separation is computed with the haversine formula.

    Note: a second, identical definition of this function appears later in
    this file and is the one bound to the name once the module has loaded.
    """
    sin_half_dlat = np.sin((dec1 - dec2)*np.pi/180./2.0)
    sin_half_dlon = np.sin((ra1 - ra2)*np.pi/180./2.0)
    cos_product = np.cos(dec1*np.pi/180.)*np.cos(dec2*np.pi/180.)

    haversine = sin_half_dlat**2 + cos_product*sin_half_dlon**2
    separation = 2.0*np.arcsin(np.sqrt(haversine))

    return separation/np.pi*180.

def gen_circle_zero(ra, dec, error, nside):
    """Build a flat, normalised map with discs around given positions zeroed.

    Parameters
    ----------
    ra, dec : number or str
        Position(s) in degrees.  A string may hold a comma-separated list;
        both must describe the same number of positions.
    error : float
        Disc radius in degrees; pixels whose centre lies within this angular
        distance of a position are set to zero.
    nside : int
        HEALPix resolution of the output map.

    Returns
    -------
    numpy.ndarray or None
        The masked map divided by its sum, or ``None`` (after printing a
        message) when the numbers of RAs and Decs differ.
    """
    # Validate first: parsing each coordinate separately lets mismatched
    # lists be reported instead of failing on a ragged array, and avoids
    # building the full-sky pixel grid for input that will be rejected.
    ra_values = [float(v) for v in ra.split(",")] if isinstance(ra, str) else [float(ra)]
    dec_values = [float(v) for v in dec.split(",")] if isinstance(dec, str) else [float(dec)]
    if len(ra_values) != len(dec_values):
        print("Need equal number of RAs and decs")
        return None

    npix = hp.nside2npix(nside)
    theta, phi = hp.pix2ang(nside, np.arange(npix))
    ras = np.rad2deg(phi)
    decs = np.rad2deg(0.5*np.pi - theta)

    n = np.ones(npix)
    for ra_centre, dec_centre in zip(ra_values, dec_values):
        dist = angular_distance(ras, decs, ra_centre, dec_centre)
        n[dist <= error] = 0
    healpix = n / np.sum(n)

    return healpix

def gen_square(ras,decs,nside):
    """Build a flat, normalised map over the polygon with the given corners.

    Parameters
    ----------
    ras, decs : sequence of float
        Corner longitudes and latitudes in degrees, paired in order.  If the
        sequences differ in length the extra entries are ignored (``zip``).
    nside : int
        HEALPix resolution of the output map.

    Returns
    -------
    numpy.ndarray or None
        Map that is constant inside the polygon and sums to one, or ``None``
        (after printing a message) when ``ras`` is a comma-separated string.

    Raises
    ------
    ValueError
        If fewer than three corners are supplied.
    """
    # Only a raw string can contain a comma; a parsed list of numbers cannot.
    if isinstance(ras, str) and ',' in ras:
        print("Use --doCircle for multiple locations")
        return None

    corners = list(zip(ras, decs))
    if len(corners) < 3:
        raise ValueError("A polygon needs at least 3 vertices, got %d"
                         % len(corners))

    radecs = np.asarray(corners, dtype=float)
    xyz = hp.ang2vec(radecs[:, 0], radecs[:, 1], lonlat=True)
    ipix = hp.query_polygon(nside, xyz)

    npix = hp.nside2npix(nside)
    n = np.zeros(npix)
    n[ipix] = 1.
    healpix = n / np.sum(n)

    return healpix

#def gen_TESS(ras,decs,nside):
def gen_TESS(lats,lons,nside):
    """Build a flat, normalised map over square fields at given centres.

    For every (lat, lon) pair the pixels of a square field are requested from
    ``gwemopt.utils.getSquarePixels`` with a side parameter of 24 (presumably
    degrees -- confirm against that helper).  The union of all pixels is set
    to a constant, the map is normalised, and it is finally rotated with a
    healpy ``Rotator`` from coordinate system 'E' to 'C'.
    """
    tess_fov = 24

    # Start from an empty array so that an empty input still yields an
    # (empty) index array after stacking.
    pieces = [np.array([])]
    for lat, lon in zip(lats, lons):
        print(lon,lat)
        ipix, _radecs, _patch, _area = gwemopt.utils.getSquarePixels(lon,lat,tess_fov,nside)
        pieces.append(ipix)
    pixels = np.unique(np.hstack(pieces)).astype(int)

    footprint = np.zeros(hp.nside2npix(nside))
    footprint[pixels] = 1.
    footprint = footprint / np.sum(footprint)

    rotator = hp.rotator.Rotator(coord=['E','C'])
    return rotator.rotate_map(footprint)

def angular_distance(ra1, dec1, ra2, dec2):
    """Return the great-circle separation, in degrees, of two sky positions.

    All inputs are in degrees and may be scalars or NumPy arrays (normal
    broadcasting applies).  The separation is computed with the haversine
    formula.

    This definition replaces the identical one that appears earlier in this
    file, so it is the one every caller actually uses.
    """
    delt_lon = (ra1 - ra2)*np.pi/180.
    delt_lat = (dec1 - dec2)*np.pi/180.
    haversine = (np.sin(delt_lat/2.0)**2
                 + np.cos(dec1*np.pi/180.)*np.cos(dec2*np.pi/180.)
                 * np.sin(delt_lon/2.0)**2)
    # Rounding can push the term a hair outside [0, 1] for (nearly)
    # antipodal points, which would make arcsin return NaN.
    haversine = np.clip(haversine, 0.0, 1.0)
    dist = 2.0*np.arcsin(np.sqrt(haversine))

    return dist/np.pi*180.

def do_getfields(healpix, FOV=60/3600.0, ra=None, dec=None, radius=None,
                 level=None):
    """Select spiral-tessellation field centres by disc or credible region.

    Parameters
    ----------
    healpix : numpy.ndarray
        Probability map; only used when selecting by ``level``.
    FOV : float
        Tile radius handed to ``tesselation_spiral``.
    ra, dec, radius : number or numeric str, optional
        When both ``ra`` and ``dec`` are given, keep the centres within
        ``radius`` degrees of that position.
    level : number or numeric str, optional
        Otherwise, when given, keep the centres that fall inside the contour
        of this credible level (in percent) of ``healpix``.

    Returns
    -------
    (ras, decs) : tuple of numpy.ndarray
        The selected centres; the full tessellation if neither selection
        applies.

    Raises
    ------
    ValueError
        If ``ra`` and ``dec`` are given without ``radius``.
    """
    by_position = ra is not None and dec is not None
    if by_position and radius is None:
        # Fail before the (potentially very large) tessellation is built.
        raise ValueError("radius is required when ra and dec are given")

    ras, decs = tesselation_spiral(FOV, scale=0.80)
    if by_position:
        # Command-line values may arrive as strings; make them numeric.
        dist = angular_distance(ras, decs, float(ra), float(dec))
        keep = dist <= float(radius)
        ras, decs = ras[keep], decs[keep]
    elif level is not None:
        cls = 100 * postprocess.find_greedy_credible_levels(healpix)
        paths = postprocess.contour(cls, [float(level)], degrees=True,
                                    simplify=True)
        # One level was requested, so take its list of polygons.
        paths = paths[0]

        pts = np.column_stack((ras, decs))
        inside = np.zeros(len(ras), dtype=bool)
        for path in paths:
            # A centre is kept if it lies inside any of the polygons.
            inside |= matplotlib.path.Path(path).contains_points(pts)
        ras, decs = ras[inside], decs[inside]

    return ras, decs

# =============================================================================
#
#                                    MAIN
#
# =============================================================================

warnings.filterwarnings("ignore")

# Parse command line
opts = parse_commandline()

# Build the base map from the first generator selected on the command line;
# with none selected, start from a uniform all-sky map.
if opts.doCircle:
    healpix = gen_circle(opts.ra,opts.dec,opts.radius,opts.nside)
elif opts.doCircleZero:
    healpix = gen_circle_zero(opts.ra,opts.dec,opts.radius,opts.nside)
elif opts.doSquare:
    ras = [float(x) for x in opts.ras.split(",")]
    decs = [float(x) for x in opts.decs.split(",")]
    healpix = gen_square(ras,decs,opts.nside)
elif opts.doAnnulus:
    healpix = gen_annulus(opts.ra,opts.dec,opts.radius,opts.radius_inner,opts.radius_outer,opts.nside)
elif opts.doTESS:
    lats = [float(x) for x in opts.lat.split(",")]
    lons = [float(x) for x in opts.lon.split(",")]
    healpix = gen_TESS(lats,lons,opts.nside)
else:
    npix = hp.nside2npix(opts.nside)
    n = np.ones(npix)
    healpix = n / np.sum(n)

# Some generators print a message and return None on bad input; stop here
# rather than failing later with an unrelated error.
if healpix is None:
    sys.exit("No skymap could be generated from the given options.")

# Optional modifiers, applied in this fixed order.
if opts.doIntersect:
    healpix = get_intersect(healpix,opts.original_skymap,opts.nside)

if opts.doLongitudeCut:
    healpix = get_longitudecut(healpix, opts.nside, opts.longitude_low, opts.longitude_high)

if opts.doExcise:
    healpix = do_excise(healpix,opts.excise_file)

if opts.doGetFields:
    if opts.fields_type == "circle":
        ras, decs = do_getfields(healpix, ra=opts.ra, dec=opts.dec, radius=opts.radius)
    elif opts.fields_type == "level":
        ras, decs = do_getfields(healpix, level=opts.level)
    else:
        sys.exit("Unknown --fields_type '%s' (expected 'circle' or 'level')."
                 % opts.fields_type)
    for ra, dec in zip(ras, decs):
        print('%.10f %.10f' % (ra, dec))


#hp.fitsfunc.write_map(opts.skymap,healpix,overwrite=True)
io.write_sky_map(opts.skymap, healpix, moc=True)

if opts.doCandidate:
    # Rank pixels by decreasing probability and report the cumulative
    # probability reached at the pixel containing the candidate.
    prob = healpix / np.sum(healpix)
    idy = np.argsort(prob)[::-1]
    cumprob = np.cumsum(prob[idy])
    ipix = hp.ang2pix(opts.nside, opts.ra_candidate, opts.dec_candidate,
                      lonlat=True)
    # idy is a permutation of the pixel indices, so there is exactly one hit;
    # index it to format a scalar rather than a one-element array.
    rank = np.where(ipix == idy)[0][0]
    print('Transient is in the %.5f percentile' % (cumprob[rank]*100))

unit='Gravitational-wave probability'
cbar=False

# Keyword arguments shared by every mollview call below.
view_kwargs = dict(title='', unit=unit, cbar=cbar, min=0, max=np.max(healpix))

plt.figure()
if opts.doGalactic:
    hp.mollview(healpix, coord=['C','G'], **view_kwargs)
else:
    if opts.doRotate:
        hp.mollview(healpix, rot=[opts.theta, opts.phi], coord='C', **view_kwargs)
    else:
        hp.mollview(healpix, coord='C', **view_kwargs)
    gwemopt.plotting.add_edges()
    # galaxy: lines of constant galactic latitude
    gal_lon = np.arange(0., 360, 0.036)
    for gallat in [15,0,-15]:
        gal_lat = gallat*np.ones_like(gal_lon)
        hp.projplot(gal_lon, gal_lat, 'w-', coord='G',lonlat=True,lw=2)
if opts.doCandidate:
    hp.visufunc.projscatter(opts.ra_candidate, opts.dec_candidate, lonlat=True)
# Save before show so the file is written even if a backend's show() were to
# consume the figure (with the Agg backend selected above it does nothing).
plt.savefig(opts.skymapplot,dpi=200)
plt.show()
plt.close('all')

