import healpy as hp
import numpy as np
import copy
import astropy.coordinates
from astropy.time import Time, TimeDelta
import astropy.coordinates
import astropy.units as u
import matplotlib.pyplot as plt
import seaborn as sns
import os, optparse

def skymap_evol(maplal,mapbaye,outputDir):
    """Compare a BAYESTAR sky map with a LALInference sky map.

    Reads both HEALPix maps, selects the BAYESTAR pixels inside the 90%
    credible region, prints the BAYESTAR probability contained in the
    overlap with the LALInference pixel grid, and produces two plots in
    ``outputDir``:

    * ``cumsum_difference.pdf`` -- cumulative sum of the per-pixel absolute
      differences (largest first) inside the 90% region;
    * ``mollview_difference.pdf`` -- Mollweide view of the BAYESTAR map in
      which the pixels of the 90% region are replaced by the absolute
      difference to the LALInference map.

    Parameters
    ----------
    maplal : str
        Path of the LALInference map, passed to ``hp.read_map``.
    mapbaye : str
        Path of the BAYESTAR map, passed to ``hp.read_map``.
    outputDir : str
        Existing directory in which the PDF files are written.

    Returns
    -------
    None
        Results are printed, saved to disk and shown with ``plt.show()``.
    """
    hpx_lal, _ = hp.read_map(maplal,verbose=False,h=True)
    hpx_baye, _ = hp.read_map(mapbaye,verbose=False,h=True)
    # Working copy that receives the differences; copying avoids reading the
    # same file from disk a second time.
    hpx_baye_new = hpx_baye.copy()

    npix_baye = len(hpx_baye)
    nside_baye = hp.npix2nside(npix_baye)
    npix_lal = len(hpx_lal)
    nside_lal = hp.npix2nside(npix_lal)

    # Credible level of each BAYESTAR pixel: cumulative probability of all
    # pixels that are at least as probable as it.
    order = np.flipud(np.argsort(hpx_baye))
    sorted_credible_levels = np.cumsum(hpx_baye[order])
    credible_levels = np.empty_like(sorted_credible_levels)
    credible_levels[order] = sorted_credible_levels

    ipix_baye = np.flatnonzero(credible_levels <= 0.90)

    # Overlap calculation: map every LALInference pixel centre onto the
    # BAYESTAR grid and intersect with the 90% region.
    # NOTE(review): every pixel of the LALInference map is used here, so when
    # nside_lal >= nside_baye the intersection is the whole 90% region and
    # the printed value is just its enclosed probability -- confirm whether a
    # credible-region cut on the LALInference map was intended.
    theta_lal, phi_lal = hp.pix2ang(nside_lal, np.arange(npix_lal))
    ipix_lal_ref = hp.ang2pix(nside_baye, theta_lal, phi_lal)

    intersection = np.intersect1d(ipix_baye, ipix_lal_ref)
    prob = np.sum(hpx_baye[intersection], dtype=np.float64)
    print("intersection of lalinf map with 90% of bayestar: ", prob)

    # Difference calculation: look up the LALInference pixel under the centre
    # of each selected BAYESTAR pixel and store |p_bayestar - p_lalinf|.
    # NOTE(review): pixels outside the 90% region keep their original
    # BAYESTAR probability in hpx_baye_new, and no pixel-area rescaling is
    # applied when the two maps have different nside -- confirm intent.
    theta_baye, phi_baye = hp.pix2ang(nside_baye, ipix_baye)
    ipix_ref = hp.ang2pix(nside_lal, theta_baye, phi_baye)
    hpx_baye_new[ipix_baye] = np.abs(hpx_baye[ipix_baye]-hpx_lal[ipix_ref])

    # Plot cumulative sum of differences, largest difference first.
    difs = np.sort(hpx_baye_new[ipix_baye])[::-1]
    cumsum = np.cumsum(difs)

    sns.set_style("dark")
    plotName = os.path.join(outputDir,'cumsum_difference.pdf')

    plt.plot(np.arange(len(cumsum)),cumsum,c='b')
    plt.title('S200213t Cumulative Sum')
    plt.ylabel("|prob(bayestar) - prob(lalinf)|")
    plt.tight_layout()
    plt.savefig(plotName,dpi=plt.gcf().dpi)
    plt.show()
    plt.close('all')

    # Plot difference using mollweide projection.
    unit='Gravitational-wave probability'
    cbar=False
    # ligo.skymap is optional: importing it registers the "cylon" colormap.
    try:
        import ligo.skymap.plot
        cmap = "cylon"
    except ImportError:
        cmap = 'PuBuGn'

    plotName = os.path.join(outputDir,'mollview_difference.pdf')
    # hp.mollview opens its own figure, so no explicit plt.figure() is needed
    # (creating one only leaves an empty extra window).
    hp.mollview(
                hpx_baye_new,title='Difference between 90% of Bayestar and LALInf maps',unit=unit,cbar=cbar,
                min=np.min(hpx_baye_new),max=np.max(hpx_baye_new),
                cmap=cmap)
    add_edges()
    plt.savefig(plotName,dpi=plt.gcf().dpi)
    plt.show()
    plt.close('all')

def observability(map,gpstime,airmass,telescopes,telescope_info,outputDir):
    """Plot the part of a sky map that the given telescopes can observe.

    The probability map is normalised to unit sum. A pixel counts as
    observable if, for at least one telescope and at least one sampled time
    (every 6 hours during the 7 days following ``gpstime``), its altitude is
    at least 30 deg, sec(z) does not exceed ``airmass`` and the Sun is at or
    below -18 deg altitude. The probability of all other pixels is set to
    zero and the result is saved as ``observability.pdf`` in ``outputDir``.

    Parameters
    ----------
    map : str
        Path of the HEALPix map, passed to ``hp.read_map``. (The name
        shadows the builtin; it is kept for backward compatibility.)
    gpstime : float
        Event time as GPS seconds.
    airmass : float
        Upper limit applied to sec(z) of a pixel.
    telescopes : iterable of str
        Keys of ``telescope_info`` to include.
    telescope_info : dict
        Maps a telescope name to ``[latitude_deg, longitude_deg, height_m]``.
    outputDir : str
        Existing directory in which the PDF file is written.

    Raises
    ------
    KeyError
        If a name in ``telescopes`` is missing from ``telescope_info``.
    """
    prob, _ = hp.read_map(map,field=0,verbose=False,h=True)
    prob = prob / np.sum(prob)

    npix = len(prob)
    nside = hp.npix2nside(npix)
    event_time = Time(gpstime, format='gps', scale='utc')
    # Time offsets from the event in days: 7 days sampled every 6 hours.
    dts = np.arange(0,7,1.0/4.0)

    # Look up (celestial) spherical polar coordinates of HEALPix grid.
    theta, phi = hp.pix2ang(nside, np.arange(npix))
    # Convert to RA, Dec.
    radecs = astropy.coordinates.SkyCoord(
                                          ra=phi*u.rad, dec=(0.5*np.pi - theta)*u.rad)

    # 1 where a pixel is observable at any sampled time from any telescope.
    observable = np.zeros((npix,))

    for telescope in telescopes:
        lat, lon, height = telescope_info[telescope][:3]
        observatory = astropy.coordinates.EarthLocation(lat=lat*u.deg, lon=lon*u.deg, height=height*u.m)

        for dt in dts:
            time = event_time+TimeDelta(dt*u.day)
            frame = astropy.coordinates.AltAz(obstime=time, location=observatory)

            # Check the Sun first: it is a single coordinate, and when it is
            # too high no pixel can qualify, so the expensive transform of
            # every pixel can be skipped.
            sun_altaz = astropy.coordinates.get_sun(time).transform_to(frame)
            if sun_altaz.alt > -18*u.deg:
                continue

            altaz = radecs.transform_to(frame)
            visible = (altaz.alt >= 30*u.deg) & (altaz.secz <= airmass)
            observable[visible] = 1

    observable_prob = prob*observable

    unit='Gravitational-wave probability'
    cbar=False
    # ligo.skymap is optional: importing it registers the "cylon" colormap.
    try:
        import ligo.skymap.plot
        cmap = "cylon"
    except ImportError:
        cmap = 'PuBuGn'

    plotName = os.path.join(outputDir,'observability.pdf')
    # hp.mollview opens its own figure, so no explicit plt.figure() is needed
    # (creating one only leaves an empty extra window).
    # The colour scale is taken from the unmasked map so that it matches a
    # plot of the full probability map.
    hp.mollview(
                observable_prob,title='',unit=unit,cbar=cbar,
                min=np.min(prob),max=np.max(prob),
                cmap=cmap)
    add_edges()
    plt.savefig(plotName,dpi=plt.gcf().dpi)
    plt.show()
    plt.close('all')

def add_edges():
    """Add a graticule and degree labels to the current healpy projection.

    Longitude labels are written along the equator every 30 deg from -150
    to 150, and latitude labels along longitude 0 every 30 deg from -60
    to 60.
    """
    hp.graticule(verbose=False)
    plt.grid(True)

    # Longitude labels, placed on the equator.
    for lon in np.arange(-150.0,180,30.0):
        hp.projtext(lon,0.0,"%.0f"%lon,lonlat=True)

    # Latitude labels, placed on longitude 0. Iterating the latitudes
    # directly avoids building a zero-longitude array of mismatched length.
    for lat in np.arange(-60.0,90,30.0):
        hp.projtext(0.0,lat,"%.0f"%lat,lonlat=True)



# Command-line entry point: parse the options, make sure the output directory
# exists, then run the requested analyses.
cwd = os.getcwd()
# The version must be a string: optparse expands "%prog" in it with
# str.replace, so a float makes --version raise AttributeError.
parser = optparse.OptionParser(usage=__doc__,version="1.0")

parser.add_option("-o", "--outputDir", help="output directory",default=f'{cwd}/plots_output')

parser.add_option("--doSkymapEvol", action="store_true", default=False)
parser.add_option("--maplal", help="GW skymap.", default="/Users/mouzaalmualla/Desktop/data/S200105ae/LALInference.fits.gz")
parser.add_option("--mapbaye", help="GW skymap.", default="/Users/mouzaalmualla/Desktop/data/S200105ae/bayestar.fits.gz")

parser.add_option("--doObservability", action="store_true", default=False)
parser.add_option("--observabilitymap", help="GW skymap.", default="/Users/mouzaalmualla/Desktop/data/GW190425/LALInference.fits.gz")
parser.add_option("--telescopes", help="Telescope names.",default ="ZTF")
parser.add_option("--gpstime", help="GPS time.", default=1240215503.0171, type=float)
parser.add_option("--airmass",default=2.5,type=float)

opts, args = parser.parse_args()

# makedirs also creates missing parent directories and does nothing when the
# directory already exists.
os.makedirs(opts.outputDir, exist_ok=True)

#LATITUDE, LONGITUDE, ELEVATION
telescope_info = {}
telescope_info["ZTF"] = [33.3563,-116.8648,1742.0]
telescope_info["TCA"] = [43.75203,6.92353,1320.0]
telescope_info["TCH"] = [-29.2608,-70.7322,2347.0]
telescope_info["TRE"] = [-21.201387,55.407463,970.0]
telescope_info["OAJ"] = [40.0420111,-1.0161911,1957.0]
telescope_info["F60"] = [40.3942,117.5750,900.0]
telescope_info["Abastunami-T48"] = [41.754021,42.820776,1610.0]
telescope_info["Abastunami-T70"] = [41.754021,42.820776,1610.0]
telescope_info["FZU-Auger"] = [-35.4956928,-69.4494508,2200]
telescope_info["FZU-CTA-N"] = [28.7621233,-17.8899317,2200]
telescope_info["GWAC"] = [40.3942,117.5750,900.0]
# NOTE(review): the OSN longitude is positive (east) here -- confirm the sign
# against the observatory's published coordinates.
telescope_info["OSN"] = [37.0629145,3.3881303,2896]
telescope_info["ShAO-T60"] = [40.782179,48.600159,156.0]
telescope_info["TNT"] = [40.3942,117.5750,900]
telescope_info["VIRT"] = [18.35234,-64.9568,420]
telescope_info["Zadko"] = [-31.356667,115.713611,50.0]


if opts.doObservability:
    # Accept "ZTF, TCA" as well as "ZTF,TCA" and ignore empty entries.
    telescope_names = [name.strip() for name in opts.telescopes.split(",") if name.strip()]
    unknown = [name for name in telescope_names if name not in telescope_info]
    if unknown:
        # Report a usage error instead of a KeyError traceback later on.
        parser.error("unknown telescope(s): %s (known: %s)"
                     % (", ".join(unknown), ", ".join(sorted(telescope_info))))
    observability(opts.observabilitymap,opts.gpstime,opts.airmass,
                  telescope_names,telescope_info,opts.outputDir)

if opts.doSkymapEvol:
    skymap_evol(opts.maplal,opts.mapbaye,opts.outputDir)
