#!/usr/bin/python

# Copyright (C) 2017 Michael Coughlin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

"""
Gravitational-wave Electromagnetic Optimization

This script generates an optimized list of pointings and content for
reviewing gravitational-wave skymap likelihoods.

Comments should be e-mailed to michael.coughlin@ligo.org.

"""


import os, sys, glob, optparse, shutil, warnings
import copy
import numpy as np
np.random.seed(0)

from astropy.table import Table
from astropy import units as u
from astropy.coordinates import Angle
from astropy.coordinates import SkyCoord
import ligo.skymap.plot
import healpy as hp

import matplotlib
#matplotlib.rc('text', usetex=True)
matplotlib.use('Agg')
matplotlib.rcParams.update({'font.size': 16})
matplotlib.rcParams['contour.negative_linestyle'] = 'solid'
from matplotlib import patches
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm

import gwemopt.utils

# Script metadata.
__author__ = "Michael Coughlin <michael.coughlin@ligo.org>"
# NOTE(review): stored as a float, yet parse_commandline passes it to
# optparse.OptionParser(version=...), which optparse documents as a string --
# confirm whether this should be "1.0".
__version__ = 1.0
__date__    = "6/17/2017"

# =============================================================================
#
#                               DEFINITIONS
#
# =============================================================================

def parse_commandline():
    """Parse the options given on the command-line.

    Returns:
        The ``optparse`` values object with attributes ``skymap``,
        ``dataDir``, ``outputDir``, ``telescopes`` (comma-separated string),
        ``nside`` (int), ``doZTF`` and ``verbose``.  Positional arguments
        are parsed but discarded.
    """
    # optparse applies string methods to `version` when --version is given,
    # so the numeric module-level __version__ is converted explicitly.
    parser = optparse.OptionParser(usage=__doc__, version=str(__version__))

    parser.add_option("-s", "--skymap", help="GW skymap.", default='../data/GW190510/LALInference.fits.gz')
    #parser.add_option("-s", "--skymap", help="GW skymap.", default='../data/GW190426/bayestar190426.fits.gz')

    parser.add_option("-d", "--dataDir", help="data directory.", default ="../data/GW190510")
    parser.add_option("-o", "--outputDir", help="output directory.", default ="../output/GW190510")

    parser.add_option("-t", "--telescopes", help="telescopes.", default ="MASTER_un,MASTER_coadd,TCH")

    parser.add_option("--nside",default=512,type=int)

    parser.add_option("--doZTF",  action="store_true", default=False)

    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="Run verbosely. (Default: False)")

    opts, args = parser.parse_args()

    # Show parameters.  sys.stderr.write is used instead of the Python 2
    # "print >> sys.stderr" statement, which raises TypeError on Python 3.
    if opts.verbose:
        report = [
            "",
            "running gwemopt_run...",
            "version: %s" % __version__,
            "",
            "***************** PARAMETERS ********************",
        ]
        for name, value in opts.__dict__.items():
            report.append(name + ":")
            report.append(str(value))
        report.append("")
        sys.stderr.write("\n".join(report) + "\n")

    return opts
    
    

def gen_coverage_one(healpix, ras, decs, FOV_type, FOV, nside):
    """Collect the HEALPix pixels covered by one telescope's pointings.

    For every (ra, dec) pair the footprint is obtained from the matching
    ``gwemopt.utils`` helper, and the probability and area enclosed by that
    single tile are printed.

    Args:
        healpix: array indexed by pixel number; its values are summed over
            each footprint to give the tile probability.
        ras, decs: paired pointing centres, passed unchanged to the
            ``gwemopt.utils`` helpers (presumably degrees -- confirm against
            gwemopt).
        FOV_type: ``"square"``, ``"circle"`` or ``"rectangle"``.
        FOV: field-of-view size passed to the helper; for ``"rectangle"`` a
            two-element sequence indexed as ``FOV[0]`` and ``FOV[1]``.
        nside: HEALPix nside of ``healpix``.

    Returns:
        ``(ipixs, radecs_list)``: the sorted unique pixel indices (int array)
        covered by any tile, and one outline array per pointing.

    Raises:
        ValueError: if ``FOV_type`` is not one of the supported shapes.
    """
    # Fail fast: previously an unknown shape either raised NameError on the
    # first pointing or silently reused the previous tile's pixels.
    if FOV_type not in ("square", "circle", "rectangle"):
        raise ValueError("Unknown FOV_type %r; expected 'square', 'circle' "
                         "or 'rectangle'" % (FOV_type,))

    pixarea_deg2 = hp.nside2pixarea(nside, degrees=True)

    print(FOV_type)
    radecs_list = []
    # Seeded with an empty float array so the final hstack also works when
    # there are no pointings (same accumulation dtype as before).
    ipix_chunks = [np.array([])]
    for ra, dec in zip(ras, decs):
        if FOV_type == "square":
            ipix, radecs, patch, area = gwemopt.utils.getSquarePixels(ra,dec,FOV,nside)
            if len(radecs) == 4:
                # Swap the last two corners.
                # NOTE(review): unlike the rectangle branch, the outline is
                # not closed by repeating the first corner -- confirm intended.
                radecs = radecs[[0, 1, 3, 2],:]
        elif FOV_type == "circle":
            ipix, radecs, patch, area = gwemopt.utils.getCirclePixels(ra,dec,FOV,nside)
        else:  # "rectangle"
            ipix, radecs, patch, area = gwemopt.utils.getRectanglePixels(ra,dec,FOV[0],FOV[1],nside)
            if len(radecs) == 4:
                # Swap the last two corners, then repeat the first corner so
                # the outline is closed.
                radecs = radecs[[0, 1, 3, 2],:]
                radecs = np.vstack((radecs,radecs[0,:]))

        singlecum_prob = np.sum(healpix[ipix])
        singlecum_area = len(ipix) * pixarea_deg2
        print('  probability, Area in tile ('+str(np.round(ra,3))+' , '+str(np.round(dec,3))+') : '+str(np.round(singlecum_prob*100.,3))+' '+str(np.round(singlecum_area,1)))

        ipix_chunks.append(ipix)
        radecs_list.append(radecs)

    # Overlapping tiles share pixels; keep each pixel once.
    ipixs = np.unique(np.hstack(ipix_chunks)).astype(int)
    return ipixs, radecs_list
    
def gen_coverage_multiple(healpix,OBS,nside,outputDir):
    """Plot every telescope's tiles over the skymap and print coverage totals.

    Args:
        healpix: array indexed by HEALPix pixel number, drawn as the
            background image and summed to obtain enclosed probabilities.
        OBS: sequence of ``[tel_name, ras, decs, FOV_type, FOV]`` entries,
            one per telescope; the last four items are forwarded to
            ``gen_coverage_one``.
        nside: HEALPix nside of ``healpix``.
        outputDir: directory in which ``tiles.pdf`` is written.

    Returns:
        None.  The per-telescope and overall cumulative probability and area
        are printed, and the figure is saved to ``<outputDir>/tiles.pdf``.
    """
    pixarea_deg2 = hp.nside2pixarea(nside, degrees=True)

    # NOTE: the zoom centre and radius are hard-coded for one sky region.
    center2 = SkyCoord('6h', -33.5, unit=['hourangle','deg'])

    plotName = os.path.join(outputDir,'tiles.pdf')
    plt.figure()
    # The 'astro zoom' projection is presumably registered by the file's
    # ligo.skymap.plot import.
    ax = plt.axes([0.1, 0.1, 0.8, 0.8],
        projection='astro zoom',
        center=center2,
        radius=50*u.deg)
    ax.imshow_hpx(healpix, cmap='cylon')
    world = ax.get_transform('world')

    # One colour per telescope.
    colors=cm.rainbow(np.linspace(0,1,len(OBS)))

    # Seeded with an empty array so the final hstack works for an empty OBS.
    ipix_chunks = [np.array([])]
    for color, obs in zip(colors, OBS):
        tel_name, ras, decs, FOV_type, FOV = obs[:5]
        print(tel_name)

        single_ipixs, radecs_list =gen_coverage_one(healpix, ras, decs, FOV_type, FOV, nside)
        ipix_chunks.append(single_ipixs)
        single_ipixs=np.unique(np.array(single_ipixs).astype(int))
        singlecum_prob = np.sum(healpix[single_ipixs])
        singlecum_area = len(single_ipixs) * pixarea_deg2

        for radecs in radecs_list:
            # Outline of the tile plus a translucent filled patch.
            coords = SkyCoord(np.squeeze(radecs), unit='deg')
            ax.plot(coords.ra, coords.dec, color=color, transform=world, zorder=1)

            poly = patches.Polygon(radecs, transform=world, alpha=0.5, color=color)
            ax.add_patch(poly)

        print(tel_name+'  Total Cumulative Probability, Area '+str(np.round(singlecum_prob*100.,2))+' '+str(np.round(singlecum_area,1)))
        print("********************")
    ax.grid()
    # Save before show(): on interactive backends the figure may be gone once
    # show() returns, which would leave an empty file.
    plt.savefig(plotName,dpi=200)
    plt.show()
    plt.close('all')

    # Pixels seen by several telescopes count once in the overall totals.
    ipixs=np.unique(np.hstack(ipix_chunks).astype(int))
    cum_prob = np.sum(healpix[ipixs])
    cum_area = len(ipixs) * pixarea_deg2
    print("********************")
    print('Total Cumulative Probability, Area '+str(np.round(cum_prob*100.,3))+' '+str(np.round(cum_area,1)))

# =============================================================================
#
#                                    MAIN
#
# =============================================================================

# Suppress all warnings for the remainder of the run.
warnings.filterwarnings("ignore")

# Command-line options and the values derived directly from them.
opts = parse_commandline()
nside = opts.nside
outputDir = opts.outputDir
TEL = opts.telescopes.split(",")

# Column indices default to 0.
Ra_indice, Dec_indice = 0, 0

# Filled with one [tel_name, ras, decs, FOV_type, FOV] entry per telescope.
OBS = []

# Create the output directory on first use.
if not os.path.isdir(outputDir):
    os.makedirs(outputDir)

# Read the skymap, resample it to the requested nside and normalise it so
# that the pixel values sum to one.
# NOTE(review): power=-2 is documented by healpy as keeping the map sum
# invariant under resampling -- confirm against the healpy version in use.
healpix = hp.ud_grade(hp.read_map(opts.skymap), opts.nside, power=-2)
healpix = healpix / np.sum(healpix)

def _layout(fov_type, fov, ra_col, dec_col, angles=True, sep=" ",
            stem=None, keep=None):
    """Bundle one telescope's field of view and pointing-file format."""
    return {
        "fov_type": fov_type,   # shape name understood by gen_coverage_one
        "fov": fov,             # scalar size, or (ra, dec) sizes for rectangles
        "ra_col": ra_col,       # index of the RA field after splitting
        "dec_col": dec_col,     # index of the Dec field after splitting
        "angles": angles,       # True: astropy Angle strings; False: floats
        "sep": sep,             # separator for Angle files (None = whitespace)
        "stem": stem,           # file stem if it differs from the telescope name
        "keep": keep,           # optional row filter taking the split fields
    }


def _master_not_coadd(fields):
    """Select MASTER rows whose field 8 is not exactly "Coadd"."""
    # NOTE(review): exact match here but substring match in _master_coadd,
    # so a padded or decorated "Coadd" field is accepted by both selections
    # -- confirm whether the two were meant to be complementary.
    return fields[8] != "Coadd"


def _master_coadd(fields):
    """Select MASTER rows whose field 8 contains "Coadd"."""
    return "Coadd" in fields[8]


# Supported telescopes.  Files are read from <dataDir>/<stem>.dat.
_TELESCOPES = {
    # Space-separated tables (empty fields dropped), coordinates as floats.
    "TCA":      _layout("square", 1.9, 3, 4, angles=False),
    "OAJ":      _layout("square", 1.4, 3, 4, angles=False),
    "TCH":      _layout("square", 1.9, 3, 4, angles=False),
    "LesMakes": _layout("rectangle", (0.42, 0.27), 4, 5, angles=False),
    "TRE":      _layout("square", 4.2, 3, 4, angles=False),
    # Previously built its path from an undefined name `alert` (NameError);
    # it now reads <dataDir>/CNEOST.dat like every other telescope.
    "CNEOST":   _layout("square", 3.0, 1, 2, angles=False),
    # RA parsed by astropy Angle in hours, Dec in degrees.
    "Xinglong": _layout("square", 1.5, 0, 1),
    "HMT":      _layout("square", 0.75, 0, 1),
    "HSC":      _layout("circle", 0.7, 1, 2),
    "DDOTI":    _layout("square", 7.0, 0, 1),
    "SOAR":     _layout("circle", 0.06, 0, 1, sep=None),
    "KMTNet":   _layout("square", 2.0, 1, 2, sep=None),
    # Both MASTER selections read the same "|"-separated MASTER.dat.
    "MASTER_un":    _layout("rectangle", (2.05, 4.1), 3, 4, sep="|",
                            stem="MASTER", keep=_master_not_coadd),
    "MASTER_coadd": _layout("rectangle", (2.05, 2.05), 3, 4, sep="|",
                            stem="MASTER", keep=_master_coadd),
}


def _read_pointings(path, layout):
    """Read the pointing centres listed in *path*.

    Angle columns are converted to degrees; float columns are returned as
    read.  Blank lines are skipped.  Returns ``(ras, decs)`` as numpy arrays.
    """
    ras, decs = [], []
    with open(path) as handle:
        for raw in handle:
            line = raw.rstrip('\n')
            if not line.strip():
                continue
            if layout["angles"]:
                fields = line.split(layout["sep"])
            else:
                # Runs of spaces must not shift the column indices.
                fields = [field for field in line.split(" ") if field]
            if layout["keep"] is not None and not layout["keep"](fields):
                continue
            ra_text = fields[layout["ra_col"]]
            dec_text = fields[layout["dec_col"]]
            if layout["angles"]:
                ras.append(Angle(ra_text, unit=u.hour).deg)
                decs.append(Angle(dec_text, unit=u.deg).deg)
            else:
                ras.append(float(ra_text))
                decs.append(float(dec_text))
    return np.array(ras), np.array(decs)


# Reject unknown names before reading anything: previously they silently
# reused the previous telescope's settings or failed with a NameError.
_unknown = [name for name in TEL if name not in _TELESCOPES]
if _unknown:
    raise ValueError("Unknown telescope(s) %s; supported: %s"
                     % (", ".join(repr(name) for name in _unknown),
                        ", ".join(sorted(_TELESCOPES))))

for tel_name in TEL:
    layout = _TELESCOPES[tel_name]
    filename = os.path.join(opts.dataDir, '%s.dat' % (layout["stem"] or tel_name))
    ras, decs = _read_pointings(filename, layout)

    FOV_type = layout["fov_type"]
    # Rectangles carry [FOV_ra, FOV_dec]; other shapes a single size.
    FOV = list(layout["fov"]) if FOV_type == "rectangle" else layout["fov"]

    OBS.append([tel_name,ras,decs,FOV_type, FOV])

gen_coverage_multiple(healpix,OBS,nside,outputDir)

#gen_coverage(healpix, ras, decs, FOV_type, FOV, )


