#!/usr/bin/python

# Copyright (C) 2017 Michael Coughlin
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.

""".
Gravitational-wave Electromagnetic Optimization

This script generates an optimized list of pointings and content for
reviewing gravitational-wave skymap likelihoods.

Comments should be e-mailed to michael.coughlin@ligo.org.

"""


import os, sys, glob, optparse, shutil, warnings
import copy
import numpy as np
from scipy.stats import norm, rv_discrete
np.random.seed(0)

import healpy as hp
import pandas as pd
from astropy.table import unique, vstack, Table
from astropy import units as u
from astropy.time import Time, TimeDelta
from astropy import cosmology
import astropy.coordinates
from astropy.coordinates import Distance
from astropy.coordinates import SkyCoord
from astropy.coordinates import EarthLocation
from astropy.io import fits

import matplotlib
#matplotlib.rc('text', usetex=True)
matplotlib.use('Agg')
matplotlib.rcParams.update({'font.size': 16})
matplotlib.rcParams['contour.negative_linestyle'] = 'solid'
from matplotlib import patches
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from mpl_toolkits.axes_grid1 import make_axes_locatable

import ligo.skymap.distance as ligodist

import gwemopt.utils, gwemopt.plotting
import gwemopt.moc, gwemopt.tiles
import gwemopt.coverage
import gwemopt.segments, gwemopt.scheduler
import time

try:
    import ligo.skymap.plot
    cmap = "cylon"
except:
    cmap = 'PuBuGn'

__author__ = "Michael Coughlin <michael.coughlin@ligo.org>"
__version__ = 1.0
__date__    = "6/17/2017"

# =============================================================================
#
#                               DEFINITIONS
#
# =============================================================================

def parse_commandline():
    """Parse the options given on the command line.

    Returns
    -------
    optparse.Values
        The parsed options.  When ``--verbose`` is given, every option name
        and its value are echoed to stderr.
    """
    parser = optparse.OptionParser(usage=__doc__,version=__version__)

    parser.add_option("-c", "--configDirectory", help="GW-EM config file directory.", default ="../config/")
    #parser.add_option("-o", "--outputDir", help="output directory",default="../output/S200115j")
    parser.add_option("-i","--inputDir",default="../input/")
    parser.add_option("-o", "--outputDir", help="output directory",default="../output/serendipitous")

    parser.add_option("-t", "--telescope", help="Telescope.", default ="ZTF")
    #parser.add_option("-g", "--gps", help="Event time GPS.", default=1263097398, type=float)
    parser.add_option("-g", "--gps", help="Event time GPS.", default=1240704018, type=float)

    parser.add_option("--Tobs",default="0.0,1.0")

    parser.add_option("--Tmin",default=0.0,type=float)
    parser.add_option("--Tmax",default=60.0,type=float)
    parser.add_option("--dt",default=0.2,type=float)

    parser.add_option("--observations", help="observation file.", default='../data/serendipitous/ztf_obsfile_status1_backfill_v2_0501_0901.csv')

    parser.add_option("--nside",default=256,type=int)

    parser.add_option("--exposuretimes",default="30.0,30.0")
    parser.add_option("-f","--filters",default="g,r")
    parser.add_option("--max_nb_tiles",default="-1,-1")
    parser.add_option("--mindiff",default=5400.0,type=float)

    parser.add_option("--downsample",default=1,type=int)

    parser.add_option("--doSchedule",  action="store_true", default=False)
    parser.add_option("--doPlots",  action="store_true", default=False)
    parser.add_option("--doMovie",  action="store_true", default=False)
    parser.add_option("--doObservability",  action="store_true", default=False)
    parser.add_option("--doAirmassWeights",  action="store_true", default=False)
    parser.add_option("--NFields",default=-1,type=int)
    parser.add_option("--doEvaluate",  action="store_true", default=False)
    parser.add_option("--doObsFile",  action="store_true", default=False)

    parser.add_option("-v", "--verbose", action="store_true", default=False,
                      help="Run verbosely. (Default: False)")

    opts, _ = parser.parse_args()

    # show parameters
    # (Python 3 print function; the old `print >> sys.stderr, ...` form
    # raises TypeError under Python 3.)
    if opts.verbose:
        print("", file=sys.stderr)
        print("running gwemopt_run...", file=sys.stderr)
        print("version: %s"%__version__, file=sys.stderr)
        print("", file=sys.stderr)
        print("***************** PARAMETERS ********************", file=sys.stderr)
        for name, value in vars(opts).items():
            print(name+":", file=sys.stderr)
            print(value, file=sys.stderr)
        print("", file=sys.stderr)

    return opts

# =============================================================================
#
#                                    MAIN
#
# =============================================================================

warnings.filterwarnings("ignore")

# Parse command line
opts = parse_commandline()

# Load <configDirectory>/<telescope>.config for every telescope requested via
# --telescope, then rescale its limiting magnitude from the configured
# exposure time to the requested one.
params = {}
params["config"] = {}
requested_telescopes = opts.telescope.split(",")
# Only the first requested exposure time is used for the rescaling.
# (np.float was removed in NumPy 1.24; builtin float is the same dtype.)
exposuretime = np.array(opts.exposuretimes.split(","),dtype=float)[0]
configFiles = glob.glob("%s/*.config"%opts.configDirectory)
for configFile in configFiles:
    telescope = os.path.basename(configFile).replace(".config","")
    if telescope not in requested_telescopes: continue
    telescope_config = gwemopt.utils.readParamsFromFile(configFile)
    params["config"][telescope] = telescope_config
    telescope_config["telescope"] = telescope
    telescope_config["tesselation"] = np.loadtxt(telescope_config["tesselationFile"],usecols=(0,1,2),comments='%')
    telescope_config["tot_obs_time"] = 1.0

    # Keep the configured values before overwriting them below.
    telescope_config["magnitude_orig"] = telescope_config["magnitude"]
    telescope_config["exposuretime_orig"] = telescope_config["exposuretime"]

    # Shift the limiting magnitude by -2.5*log10(sqrt(t_config / t_requested)).
    nmag = -2.5*np.log10(np.sqrt(telescope_config["exposuretime"]/exposuretime))
    telescope_config["magnitude"] = telescope_config["magnitude"] + nmag
    telescope_config["exposuretime"] = exposuretime

# Top-level output directory (created if missing).
params["outputDir"] = opts.outputDir
if not os.path.isdir(params["outputDir"]):
    os.makedirs(params["outputDir"])
baseoutputDir = params["outputDir"]

# Fixed gwemopt tiling/scheduling settings for this run.
params["tilesType"] = "moc"
params["doMinimalTiling"] = False
params["doParallel"] = False
params["telescopes"] = opts.telescope.split(",")
params["nside"] = opts.nside
params["doChipGaps"] = False
params["doSingleExposure"] = True
params["doAlternatingFilters"] = True
params["doRASlices"] = False
params["doBalanceExposure"] = True
params["powerlaw_n"], params["powerlaw_cl"], params["powerlaw_dist_exp"] = 0.0, 0.9, 0.0
params["gpstime"] = opts.gps
params["mindiff"] = opts.mindiff
params["doMindifFilt"] = True
# Comma-separated CLI values -> float arrays.  (np.float was removed in
# NumPy 1.24; dtype=float is the identical dtype.)
params["Tobs"] = np.array(opts.Tobs.split(","),dtype=float)
params["filters"] = opts.filters.split(",")
params["exposuretimes"] = np.array(opts.exposuretimes.split(","),dtype=float)
params["doMaxTiles"] = False
params["doUsePrimary"] = True
params["scheduleType"] = "greedy_slew"
params["airmass"] = 2.0
params["max_nb_tiles"] = np.array(opts.max_nb_tiles.split(","),dtype=float)

# Event time as an MJD; all block times below are offsets from it.
mjd0 = Time(opts.gps, format='gps').mjd

# HEALPix pixel centres converted to equatorial coordinates (degrees).
npix = hp.nside2npix(opts.nside)
theta, phi = hp.pix2ang(opts.nside, np.arange(npix))
ra = np.rad2deg(phi)
dec = np.rad2deg(0.5*np.pi - theta)
radecs = SkyCoord(ra=ra*u.deg, dec=dec*u.deg)

# Map 1: a uniform, normalised all-sky probability.
prob_data = np.ones((npix,))
prob_data = prob_data / np.sum(prob_data)

# Cumulative probability, accumulated from the highest-probability pixel down.
sort_idx = np.argsort(prob_data)[::-1]
csm = np.empty(len(prob_data))
csm[sort_idx] = np.cumsum(prob_data[sort_idx])

pixarea_deg2 = hp.nside2pixarea(opts.nside, degrees=True)

map_struct_1 = {
    "prob": prob_data,
    "cumprob": csm,
    "ipix_keep": [],
    "pixarea_deg2": pixarea_deg2,
    "ra": ra,
    "dec": dec,
}

# Optionally plot the uniform map into the "map_1" sub-directory.
if opts.doPlots:
    map1_dir = os.path.join(baseoutputDir, "map_1")
    if not os.path.isdir(map1_dir):
        os.makedirs(map1_dir)
    params["outputDir"] = map1_dir
    print("Plotting skymap...")
    gwemopt.plotting.skymap(params, map_struct_1)

# Map 2: uniform probability with pixels at |galactic b| <= 10 deg zeroed,
# renormalised to unit sum.
prob_data = np.ones((npix,))
prob_data = prob_data / np.sum(prob_data)
idx = np.where(np.abs(radecs.galactic.b.deg) <= 10.0)[0]
prob_data[idx] = 0.0
prob_data = prob_data / np.sum(prob_data)

# Cumulative probability, accumulated from the highest-probability pixel down.
sort_idx = np.argsort(prob_data)[::-1]
csm = np.empty(len(prob_data))
csm[sort_idx] = np.cumsum(prob_data[sort_idx])

pixarea_deg2 = hp.nside2pixarea(opts.nside, degrees=True)

map_struct_2 = {
    "prob": prob_data,
    "cumprob": csm,
    "ipix_keep": [],
    "pixarea_deg2": pixarea_deg2,
    "ra": ra,
    "dec": dec,
}

# Let gwemopt validate / complete the parameter dictionary.
params = gwemopt.utils.params_checker(params)

# Independent copy of the uniform map; the observability step below
# replaces its "prob" without touching map_struct_1.
map_struct = copy.deepcopy(map_struct_1)

# Block edges in days after the event: Tmin, Tmin+dt, ... (through ~Tmax).
Tmin = opts.Tmin
Tmax = opts.Tmax
dt = opts.dt
Tobs = np.arange(Tmin, Tmax + dt, dt)

# Wall-clock timing; timings are written to timesfile.dat.
timesfile_path = os.path.join(baseoutputDir, 'timesfile.dat')
start_time_og = time.time()
timesfile = open(timesfile_path, 'w')

if opts.doObservability:
    # Rank ZTF tiles by the product of the observability maps computed at
    # Tmin and at Tmax-7 days, with |galactic b| <= 10 deg excluded.  Ntop
    # holds the tile ids in decreasing probability order.
    params["gpstime"] = (Time(opts.gps, format='gps') + TimeDelta(Tmin*u.day)).gps
    observability_struct_1 = gwemopt.utils.observability(params, map_struct)

    params["gpstime"] = (Time(opts.gps, format='gps') + TimeDelta((Tmax-7)*u.day)).gps
    observability_struct_2 = gwemopt.utils.observability(params, map_struct)
    moc_structs = gwemopt.moc.create_moc(params)

    # The Galactic-plane mask does not depend on the telescope; compute once.
    galactic_idx = np.where(np.abs(radecs.galactic.b.deg) <= 10.0)[0]

    # Defined up front so it exists even if no telescope is processed.
    Ntop = np.array([], dtype=int)
    for telescope in observability_struct_1.keys():
        # The product is a new array, so masking it leaves the structs intact.
        observability_map = observability_struct_1[telescope]["observability"]*observability_struct_2[telescope]["observability"]
        observability_map[galactic_idx] = 0.0
        observability_map = observability_map/np.sum(observability_map)

        map_struct["prob"] = observability_map
        tile_structs = gwemopt.tiles.moc(params, map_struct, moc_structs,
                                         doSegments=False)

        # Collect tiles with non-zero probability and sort them descending.
        idxs, probs = [], []
        tiles_struct = tile_structs["ZTF"]
        for index in tiles_struct.keys():
            if tiles_struct[index]["prob"] == 0.0: continue
            idxs.append(index)
            probs.append(tiles_struct[index]["prob"])
        idxs, probs = np.array(idxs), np.array(probs)
        idxsort = np.argsort(probs)[::-1]
        # A negative --NFields (the default, -1) keeps every tile; slicing
        # with [:-1] would silently drop the lowest-ranked one.
        if opts.NFields >= 0:
            idxsort = idxsort[:opts.NFields]
        Ntop = idxs[idxsort]
    print(Ntop)
params["gpstime"] = opts.gps

# Back to the top-level output directory, where the combined schedule lives.
params["outputDir"] = baseoutputDir
schedulefile = os.path.join(params["outputDir"],'schedule_ZTF.dat')
moc_structs = gwemopt.moc.create_moc(params)

# Start from the uniform map.  Tmin, Tmax and dt still hold the opts values
# assigned earlier and have not changed, so they are not reassigned here.
map_struct = map_struct_1

if not os.path.isfile(schedulefile):
    # Build the combined schedule one block at a time.  Block ii covers
    # [Tobs[ii], Tobs[ii+1]] days after the event.  Its schedule is written to
    # <baseoutputDir>/<ii>/schedule_ZTF.dat and then appended to the combined
    # <baseoutputDir>/schedule_ZTF.dat.  Odd blocks schedule program_id 2 on
    # map_struct_2; even blocks schedule program_id 1 on map_struct_1.
    #
    # NOTE(review): coverage_struct is only produced here by --doSchedule (or
    # by re-reading an existing per-block schedule).  Without --doSchedule,
    # the gap-filling code below uses an undefined or stale coverage_struct.
    epochs = {}  # field -> distinct MJDs of its program_id 2 observations
    coverage_structs = []
    # Pixels within 10 deg of the Galactic plane.  The pixel grid is fixed,
    # so do this expensive coordinate transform once rather than every block.
    galactic_idx = np.where(np.abs(radecs.galactic.b.deg) <= 10.0)[0]
    for ii in range(len(Tobs)-1):
        if np.mod(ii,5) == 0:
            start_time = time.time()
        print(f'Running block {ii}/{len(Tobs)-1}')
        params["doMaxTiles"],params["doBalanceExposure"],params["doAlternatingFilters"] = False,True,True
        params["outputDir"] = os.path.join(baseoutputDir, "%d" % ii)
        moc_structs = gwemopt.moc.create_moc(params)

        if not os.path.isdir(params["outputDir"]):
            os.makedirs(params["outputDir"])

        schedulefileind = os.path.join(params["outputDir"],'schedule_ZTF.dat')
        if os.path.isfile(schedulefileind):
            # This block was scheduled by an earlier run: reuse it.
            moc_struct = moc_structs["ZTF"]
            coverage_struct = gwemopt.coverage.read_coverage(params, "ZTF",
                                                             schedulefileind,
                                                             moc_struct=moc_struct)
        else:
            if np.mod(ii,2) == 1:
                # program_id 2 block.  From the combined schedule's last 7
                # days, give each observed field's pixels exp(-days since the
                # observation); later rows overwrite earlier ones.  Then floor
                # every pixel at 0.01 and zero the Galactic plane.
                # The combined schedule may not exist yet, e.g. if block 0
                # was skipped, so only read it if present.
                schedule_table_field = None
                if os.path.isfile(schedulefile):
                    schedule_table = pd.read_csv(schedulefile,
                                                 delimiter = ' ', header=None,
                                                 names = ('field', 'ra', 'dec',
                                                          'mjd', 'mag',
                                                          'exposure_time', 'prob',
                                                          'airmass', 'filt',
                                                          'program_id'))
                    idx = np.where((schedule_table['mjd'] >= (mjd0+Tobs[ii]-7.0)) &
                                   (schedule_table['mjd'] <= (mjd0+Tobs[ii])))[0]
                    schedule_table_field = schedule_table.iloc[idx]

                if schedule_table_field is not None and len(schedule_table_field) > 0:
                    prob_data = np.zeros((npix,))
                    for jj, row2 in schedule_table_field.iterrows():
                        field = row2["field"]
                        if not field in moc_structs["ZTF"]: continue
                        moc_struct = moc_structs["ZTF"][field]
                        ipix = moc_struct["ipix"]
                        dt = mjd0 + Tobs[ii] - row2["mjd"]
                        prob_data[ipix] = np.exp(-dt)

                        if row2["program_id"] == 2:
                            # Consecutive 7-day windows overlap, so the same
                            # observation is seen again in later blocks.
                            # Record each MJD once, so that epochs[field][-2]
                            # is a genuinely earlier epoch.
                            if field not in epochs:
                                epochs[field] = []
                            if row2["mjd"] not in epochs[field]:
                                epochs[field].append(row2["mjd"])

                    prob_data[prob_data < 0.01] = 0.01
                    prob_data[galactic_idx] = 0.0
                    prob_data = prob_data / np.sum(prob_data)
                else:
                    prob_data = np.ones((npix,))
                    prob_data[galactic_idx] = 0.0
                    prob_data = prob_data / np.sum(prob_data)

                map_struct_2["prob"] = prob_data

                sort_idx = np.argsort(map_struct_2["prob"])[::-1]
                csm = np.empty(len(map_struct_2["prob"]))
                csm[sort_idx] = np.cumsum(map_struct_2["prob"][sort_idx])

                map_struct_2["cumprob"] = csm

                map_struct = map_struct_2
                params["filters"] = ["g","r"]
                # np.float was removed in NumPy 1.24; dtype=float is identical.
                params["exposuretimes"] = np.array(opts.exposuretimes.split(","),dtype=float)
#                params["exposuretimes"] = [300.0,300.0]
                params["program_id"] = 2

                params["doBlocks"] = False
                params["Nblocks"] = 2

                if opts.doPlots:
                    print("Plotting skymap...")
                    gwemopt.plotting.skymap(params,map_struct_2)
            else:
                # program_id 1 block.  Pixels of fields observed in the
                # combined schedule's last 7 days get 1 - exp(-days since the
                # observation); all other pixels keep weight 1.
                if os.path.isfile(schedulefile):
                    prob_data = np.ones((npix,))

                    schedule_table = pd.read_csv(schedulefile,
                                                 delimiter = ' ', header=None,
                                                 names = ('field', 'ra', 'dec',
                                                          'mjd', 'mag',
                                                          'exposure_time', 'prob',
                                                          'airmass', 'filt',
                                                          'program_id'))
                    idx = np.where((schedule_table['mjd'] >= (mjd0+Tobs[ii]-7.0)) &
                                   (schedule_table['mjd'] <= (mjd0+Tobs[ii])))[0]
                    schedule_table_field = schedule_table.iloc[idx]

                    for jj, row2 in schedule_table_field.iterrows():
                        field = row2["field"]
                        if not field in moc_structs["ZTF"]: continue
                        moc_struct = moc_structs["ZTF"][field]
                        ipix = moc_struct["ipix"]
                        dt = mjd0 + Tobs[ii] - row2["mjd"]
                        prob_data[ipix] = 1.0 - np.exp(-dt)
                    prob_data = prob_data / np.sum(prob_data)
                else:
                    prob_data = np.ones((npix,))
                    prob_data = prob_data / np.sum(prob_data)

                map_struct_1["prob"] = prob_data

                sort_idx = np.argsort(map_struct_1["prob"])[::-1]
                csm = np.empty(len(map_struct_1["prob"]))
                csm[sort_idx] = np.cumsum(map_struct_1["prob"][sort_idx])

                map_struct_1["cumprob"] = csm

                map_struct = map_struct_1
#                if opts.doObservability:
#                    params["filters"] = ["g","r"] * 10
#                    params["exposuretimes"] = [30.0,30.0] * 10
                params["filters"] = ["g","r"]
                params["exposuretimes"] = [30.0,30.0]

                params["program_id"] = 1

                params["doBlocks"] = False
                params["Nblocks"] = 2

            params["Tobs"] = np.array([Tobs[ii], Tobs[ii+1]])
            params = gwemopt.segments.get_telescope_segments(params)

            # Skip blocks in which no telescope has any segment.
            cnt = 0
            for telescope in params["telescopes"]:
                if not len(params["config"][telescope]["segmentlist"]) == 0:
                    cnt = cnt + 1
            if cnt == 0:
                print('Skipping round %.2f-%.2f' % (Tobs[ii], Tobs[ii+1]))
                if np.mod(ii+1,5) == 0:
                    print(f"for {Tobs[ii]}:")
                    print(f"--- {time.time() - start_time:.2f} seconds ---")
                    timesfile.write(f"{time.time() - start_time:.2f} \n")
                continue

            config_struct = params["config"]["ZTF"]
            location = EarthLocation(config_struct["longitude"],
                                     config_struct["latitude"],
                                     config_struct["elevation"])

            #use middle time segment to determine RA slice range
            mid = len(config_struct["exposurelist"])/2
            seg = config_struct["exposurelist"][int(np.floor(mid))]

            mjds = np.linspace(seg[0], seg[1], 100)
            tt = Time(mjds, format='mjd', scale='utc', location=location)
            lst = tt.sidereal_time('mean')/u.hourangle
            lst = np.mean(lst)

            #slice within 6-hour-angle range
            ra_low = lst - 3
            ra_high = lst + 3

            if ra_low < 0: ra_low+=24
            if ra_high > 24: ra_high-=24

            # Zero every pixel outside [ra_low, ra_high] hours (the window may
            # wrap through 0h).
            ra = map_struct["ra"]
            if ra_low <= ra_high:
                sliced_ipix = np.where((ra_high*360.0/24.0 < ra) | (ra_low*360.0/24.0 > ra))[0]
            else:
                sliced_ipix = np.where((ra_high*360.0/24.0 < ra) & (ra_low*360.0/24.0 > ra))[0]

            map_struct["prob"][sliced_ipix] = 0.0

            if params["program_id"] == 2 and opts.doObservability:
                # Ntop fields get 1 - exp(-days since their last program_id 2
                # epoch); 99 days is used when the field has none.  Every other
                # non-zero pixel is set to 0.01.
                ipix_top = []
                for field in Ntop:
                    moc_struct = moc_structs["ZTF"][field]
                    if field in epochs:
                        deltat = mjd0 + Tobs[ii] - epochs[field][-1]
                    else:
                        deltat = 99
                    map_struct["prob"][moc_struct["ipix"]] = 1.0 - np.exp(-deltat)
                    ipix_top += list(moc_struct["ipix"])

                ipix_bottom = np.delete(np.arange(npix),ipix_top)
                nonzero = np.where(map_struct["prob"] != 0.0)[0]
                intersect = np.intersect1d(nonzero,ipix_bottom)
                map_struct["prob"][intersect] = 0.01

            if opts.doAirmassWeights:
                # get airmass distribution
                theta, phi = hp.pix2ang(opts.nside, np.arange(npix))
                radecs = astropy.coordinates.SkyCoord(
                                      ra=phi*u.rad, dec=(0.5*np.pi - theta)*u.rad)
                observatory = astropy.coordinates.EarthLocation(
                                                lat=config_struct["latitude"]*u.deg,
                                                lon=config_struct["longitude"]*u.deg,
                                                height=config_struct["elevation"]*u.m)

                # altaz at the midpoint of this block (uses the real block
                # length rather than assuming --dt 0.2)
                block_mid = Tobs[ii] + 0.5 * (Tobs[ii+1] - Tobs[ii])
                midpoint = Time(opts.gps, format='gps', scale='utc') + TimeDelta(block_mid * u.day)
                frame = astropy.coordinates.AltAz(obstime=midpoint,
                                                  location=observatory)
                altaz = radecs.transform_to(frame)
                airmass = altaz.secz

                # Pixels at altitude <= 30 deg get an enormous airmass, so
                # their weight below overflows and their probability becomes 0.
                horizon_mask = altaz.alt.degree <= 30.
                below_horizon_mask = horizon_mask * 10.**100
                airmass = airmass + below_horizon_mask

                # weight by airmass
                airmass_weights = 10 ** (0.2 * (airmass - 1))

                map_struct["prob"] = np.divide(map_struct["prob"],airmass_weights)

            # normalize 2D probability
            map_struct["prob"] = map_struct["prob"] / np.sum(map_struct["prob"])

            moc_structs = gwemopt.moc.create_moc(params)
            tile_structs = gwemopt.tiles.moc(params, map_struct, moc_structs)

            if opts.doPlots:
                print("Plotting tiles struct...")
                gwemopt.plotting.tiles(params, map_struct, tile_structs)

            if opts.doSchedule:
                print("Generating coverage...")
                tile_structs, coverage_struct = gwemopt.coverage.timeallocation(params, map_struct, tile_structs)

            #insert filler fields if dt >=5min

            if (coverage_struct["data"].size > 0):
                t2 = coverage_struct["data"][0,2]
                filt2 = coverage_struct["filters"][0]
                exposuretime = coverage_struct["data"][0,4]
                gap_mjds = []

                # Gap between the end of the combined schedule so far and the
                # first exposure of this block.
                if os.path.isfile(schedulefile) and (not schedule_table.empty):
                    t1 = schedule_table['mjd'][len(schedule_table)-1] + schedule_table['exposure_time'][len(schedule_table)-1]/86400.
                    dt = (t2 - t1) * 1440.
                    if dt >= 5.0:
                        gap_mjds.append(t2)

                # Gaps (in minutes) between consecutive exposures in this
                # block, net of the previous exposure time and of any filter
                # change time.
                for jj in range(len(coverage_struct["ipix"])):
                    if jj == 0: continue
                    data = coverage_struct["data"][jj,:]
                    filt1 = coverage_struct["filters"][jj]

                    t1 = data[2]
                    dt = (t1 - t2) * 1440.
                    dt = dt - exposuretime/60.
                    if filt1 != filt2 and "filt_change_time" in config_struct:
                        dt -= config_struct["filt_change_time"]
                    if dt >=5.0 and dt < 200:
                        gap_mjds.append(t1)

                    exposuretime = data[4]
                    t2 = t1
                    filt2 = filt1

                for mjd in gap_mjds:
                    if params["program_id"] == 2 and opts.doObservability:
                        # Halve the tile probability of Ntop fields whose last
                        # two program_id 2 epochs are < 0.5 d apart and whose
                        # last epoch is < 0.5 d before this block.
                        # NOTE(review): this runs once per gap, so the halving
                        # compounds when a block has several gaps; confirm.
                        for field in Ntop:
                            if field not in epochs: continue

                            deltat = mjd0 + Tobs[ii] - epochs[field][-1]
                            if len(epochs[field]) == 1: continue
                            elif (epochs[field][-1] - epochs[field][-2]) < 0.5 and deltat < 0.5:
                                tile_structs["ZTF"][field]["prob"] = tile_structs["ZTF"][field]["prob"]/2

#                        map_struct["prob"] = map_struct["prob"] / np.sum(map_struct["prob"])

                    # Re-schedule the gap ending at this exposure with a single
                    # filter and splice the new exposures in before it.
                    idx = np.where(coverage_struct["data"][:,2]==mjd)[0][0]
                    exposuretime = coverage_struct["data"][idx,4]
                    filt1 = coverage_struct["filters"][idx]
                    if idx == 0:
                        t1 = schedule_table['mjd'][len(schedule_table)-1] + exposuretime/86400.
                    else:
                        t1 = coverage_struct["data"][idx-1,2] + exposuretime/86400.

                    Tobs1 = t1 - mjd0 #beginning of block
                    Tobs2 = mjd - mjd0 #end of block
                    params["Tobs"] = np.array([Tobs1,Tobs2])
                    params["filters"] = [filt1]
                    params["exposuretimes"] = [params["exposuretimes"][0]]
                    params["doMaxTiles"],params["doBalanceExposure"],params["doAlternatingFilters"] = False,False,False

                    params = gwemopt.segments.get_telescope_segments(params)
                    tile_structs["ZTF"] = gwemopt.utils.check_overlapping_tiles(params,tile_structs["ZTF"],coverage_struct)
                    tile_structs, coverage_struct_new = gwemopt.coverage.timeallocation(params, map_struct, tile_structs)

                    idx_keep = np.where((coverage_struct_new["data"][:,2] > t1) &
                                        (coverage_struct_new["data"][:,2] < mjd - exposuretime/86400.))[0]

                    coverage_struct["data"] = np.insert(coverage_struct["data"], idx, coverage_struct_new["data"][idx_keep,:],axis=0)
                    coverage_struct["filters"] = np.insert(coverage_struct["filters"], idx, coverage_struct_new["filters"][idx_keep])
                    coverage_struct["FOV"] = np.insert(coverage_struct["FOV"], idx, coverage_struct_new["FOV"][idx_keep])
                    coverage_struct["area"] = np.insert(coverage_struct["area"], idx, coverage_struct_new["area"][idx_keep])
                    coverage_struct["telescope"] = np.insert(coverage_struct["telescope"], idx, coverage_struct_new["telescope"][idx_keep])
                    coverage_struct["ipix"][idx:idx] = [coverage_struct_new["ipix"][jj] for jj in idx_keep]
                    coverage_struct["patch"][idx:idx] = [coverage_struct_new["patch"][jj] for jj in idx_keep]

                if opts.doSchedule:
                    print("Summary of coverage...")
                    gwemopt.scheduler.summary(params,map_struct,coverage_struct)

                    if opts.doPlots:
                        gwemopt.plotting.coverage(params, map_struct, coverage_struct)

        coverage_structs.append(coverage_struct)
#        coverage_struct_combined = gwemopt.coverage.combine_coverage_structs(coverage_structs)
        new_sched_file = os.path.join(params["outputDir"],"schedule_ZTF.dat")
        OG_sched_file = os.path.join(baseoutputDir,"schedule_ZTF.dat")

        # Append this block's schedule to the combined schedule file; the
        # with-statements close both files even if an error occurs.
        with open(new_sched_file, "r") as fin:
            data = fin.read()
        with open(OG_sched_file, "a") as fout:
            fout.write(data)

#        params["outputDir"] = baseoutputDir

#        if opts.doSchedule:
#            print("Summary of coverage...")
#            gwemopt.scheduler.summary(params,map_struct,coverage_struct_combined)

#            if opts.doPlots:
#                if (np.mod(ii, 15) == 0) or (ii == len(Tobs)-2):
#                    gwemopt.plotting.coverage(params, map_struct, coverage_struct_combined)
        if np.mod(ii+1,5) == 0:
            print(f"for {Tobs[ii]}:")
            print(f"--- {time.time() - start_time:.2f} seconds ---")
            timesfile.write(f"{time.time() - start_time:.2f} \n")

# Report the total wall-clock run time to stdout and to timesfile.dat, then
# close the timing file.  Measure once so both outputs report the same value.
elapsed_og = time.time() - start_time_og
print(f"final time (all): {elapsed_og:.2f}")
timesfile.write(f"final time: {elapsed_og:.2f} \n")

timesfile.close()
if opts.doMovie:

    # load the data
    hdulist = fits.open(opts.inputDir+'Gaia_hp8_densitymap.fits')
    hist = hdulist[1].data['srcdens'][np.argsort(hdulist[1].data['hpx8'])]

    bands = {'g': 1, 'r': 2, 'i': 3, 'z': 4, 'J': 5}   
    schedule_table = pd.read_csv(schedulefile, delimiter = ' ', header=None,
                                 names = ('field', 'ra', 'dec', 'mjd',
                                          'exposure_time', 'prob',
                                          'airmass', 'filt', 'program_id'))

    #schedule_table = schedule_table[:100]
    gfields, rfields, grfields = [], [], []
    for ii, row1 in schedule_table.iterrows():
        idx = np.where((schedule_table['mjd'] >= row1['mjd']-1.0) &
                       (schedule_table['mjd'] <= row1['mjd']))[0]
        schedule_table_field = schedule_table.iloc[idx]

        idx = np.where(schedule_table_field["filt"] == "g")[0]
        schedule_table_field_g = schedule_table_field.iloc[idx]

        idx = np.where(schedule_table_field["filt"] == "r")[0]
        schedule_table_field_r = schedule_table_field.iloc[idx]

        gfield = np.unique(schedule_table_field_g["field"])
        rfield = np.unique(schedule_table_field_r["field"])    
        grfield = np.intersect1d(gfield,rfield)

        gfields.append(len(gfield))
        rfields.append(len(rfield))
        grfields.append(len(grfield))

    schedule_table["gfields"] = gfields
    schedule_table["rfields"] = rfields
    schedule_table["grfields"] = grfields

    moc_structs = gwemopt.moc.create_moc(params)

    moviedir = os.path.join(baseoutputDir,'movie')
    if not os.path.isdir(moviedir):
        os.makedirs(moviedir)

    cnt = 0 
    for ii, row1 in schedule_table.iterrows():
        plotName = os.path.join(moviedir,'movie-%04d.png'%cnt)
        if os.path.isfile(plotName):
            cnt = cnt + 1
            continue
        if not np.mod(ii,opts.downsample) == 0: continue

        fig = plt.figure(figsize=(12, 8))

        gs = fig.add_gridspec(4, 1)
        ax1 = fig.add_subplot(gs[0:3, 0], projection='astro hours mollweide')
        ax2 = fig.add_subplot(gs[3, 0])

        plt.axes(ax1)
        # plot the data in healpy
        norm ='log'

        ax1.imshow_hpx(hist, cmap='viridis', norm=LogNorm(), nested=True)
        ax1.grid()

        ra = ax1.coords[0]
        dec = ax1.coords[1]

        ra.set_ticks_visible(False)
        dec.set_ticks_visible(False)

        for jj, row2 in schedule_table.iterrows():
            if ii < jj: continue
            dt = row1["mjd"] - row2["mjd"]
            if dt > 2: continue

            field = row2["field"]
            if not field in moc_structs["ZTF"]: continue
            moc_struct = moc_structs["ZTF"][field]
            ras, decs = moc_struct["corners"][:,0], moc_struct["corners"][:,1]
            if len(ras) == 4:
                ras = [ras[0], ras[1], ras[3], ras[2], ras[0]]
                decs = [decs[0], decs[1], decs[3], decs[2], decs[0]]
                ras, decs = np.array(ras), np.array(decs)
            if row2['filt'] == 'g':
                color = 'g'
            elif row2['filt'] == 'r':
                color = 'r'
            else:
                color = 'k'

            idx = np.where((schedule_table['mjd'] >= row1['mjd']-2.0) &
                           (schedule_table['mjd'] <= row1['mjd']) &
                           (schedule_table['field'] == field))[0]
            schedule_table_field = schedule_table.iloc[idx]

            idx = np.where(schedule_table_field["filt"] == "g")[0]
            schedule_table_field_g = schedule_table_field.iloc[idx]

            idx = np.where(schedule_table_field["filt"] == "r")[0]
            schedule_table_field_r = schedule_table_field.iloc[idx]

            if ii == jj:
                alpha = 1.0
                linewidth = 3.0

                idx1, idx2 = np.where(ras>=180.0)[0], np.where(ras<180.0)[0]
                idx3, idx4 = np.where(ras>300.0)[0], np.where(ras<60.0)[0]
                if (len(idx1)>0 and len(idx2)>0) and not (len(idx3)>0 and len(idx4)>0):
                    alpha = 0.0

                poly = patches.Polygon(np.vstack((ras, decs)).T, transform=ax1.get_transform('world'), alpha=alpha, color=color)
                ax1.plot(ras, decs, color=color, transform=ax1.get_transform('world'), alpha=alpha, linewidth=linewidth)
                ax1.add_patch(poly)
            else:
                alpha = 0.5*(2-dt)*0.5
                linewidth = 2.0*(2-dt)*0.5

                idx1, idx2 = np.where(ras>=180.0)[0], np.where(ras<180.0)[0]
                idx3, idx4 = np.where(ras>300.0)[0], np.where(ras<60.0)[0]
                if (len(idx1)>0 and len(idx2)>0) and not (len(idx3)>0 and len(idx4)>0):
                    alpha = 0.0

                if (len(schedule_table_field_g) > 0) and (len(schedule_table_field_r) > 0):
                    color='magenta'
                poly = patches.Polygon(np.vstack((ras, decs)).T, transform=ax1.get_transform('world'), alpha=alpha, color=color)
                ax1.plot(ras, decs, color=color, transform=ax1.get_transform('world'), alpha=alpha, linewidth=linewidth)
                ax1.add_patch(poly)

        idx = np.where(schedule_table['mjd'] <= row1['mjd'])[0]
        schedule_table_field = schedule_table.iloc[idx]

        ax2.plot(schedule_table_field["mjd"]-mjd0, schedule_table_field["gfields"],
                 '-', color='g',label='g-band')
        ax2.plot(schedule_table_field["mjd"]-mjd0, schedule_table_field["rfields"],
                 '--', color='r',label='r-band')
        ax2.plot(schedule_table_field["mjd"]-mjd0, schedule_table_field["grfields"],
                 ':', color='magenta',label='g+r-bands')
        ax2.legend(loc=1)
        ax2.set_xlabel('Days since MJD=%.5f' % mjd0)
        ax2.set_ylabel('Fields observed last 24 hrs')

        ax2.set_xlim([np.min(schedule_table['mjd'])-mjd0,
                      np.max(schedule_table['mjd'])-mjd0])
        max_area = np.max([np.max(schedule_table['gfields']),
                           np.max(schedule_table['rfields']),
                           np.max(schedule_table['grfields'])])
        ax2.set_ylim([0,max_area])
        fig.suptitle('')

        plotName = os.path.join(moviedir,'movie-%04d.png'%cnt)
        plt.savefig(plotName,dpi=200,bbox_inches='tight')
        plt.close()

        cnt = cnt + 1

    output = "schedule"
    moviefiles = os.path.join(moviedir,"movie-%04d.png")
    filename = os.path.join(moviedir,"%s.mpg" % (output))
    ffmpeg_command = 'ffmpeg -an -y -r 20 -i %s -b:v %s %s'%(moviefiles,'5000k',filename)
    os.system(ffmpeg_command)
    filename = os.path.join(moviedir,"%s.gif" % (output))
    ffmpeg_command = 'ffmpeg -an -y -r 20 -i %s -b:v %s %s'%(moviefiles,'5000k',filename)
    os.system(ffmpeg_command)
    rm_command = "rm %s/*.png"%(moviedir)
    os.system(rm_command)

# Build a synthetic observation log ('obs.dat') for the scheduled pointings.
# For every scheduled (field, mjd, filter, exposure) row, records are borrowed
# from the real observation table (opts.observations, status == 1 only) for the
# same field id -- or, if that field has no records, the nearest lower field id
# that does -- and re-written with the scheduled jd, field, filter id, program id
# and exposure time.
obsfile = os.path.join(params["outputDir"],'obs.dat')
if opts.doObsFile and not os.path.isfile(obsfile):
    all_obs_table = pd.read_csv(opts.observations, delimiter = ',')
    obstable = all_obs_table[all_obs_table['status'] == 1]

    # NOTE(review): the existence check above uses the incoming
    # params["outputDir"], but the file is written under baseoutputDir;
    # confirm that this difference is intended.
    params["outputDir"] = baseoutputDir
    obsfile = os.path.join(params["outputDir"],'obs.dat')

    bands = {'g': 1, 'r': 2, 'i': 3, 'z': 4, 'J': 5}
    schedule_table = pd.read_csv(schedulefile, delimiter = ' ', header=None,
                                 names = ('field', 'ra', 'dec',
                                          'mjd', 'mag',
                                          'exposure_time', 'prob',
                                          'airmass', 'filt', 'program_id'))

    # Smallest field id with data: walking below it can never find a match.
    min_fieldid = obstable["fieldid"].min()

    with open(obsfile, 'w') as fid:
        fid.write('jd,fieldid,chid,progid,expid,filterid,limMag,exptime,status\n')

        for _, sched_row in schedule_table.iterrows():
            field = sched_row["field"]
            exptime = sched_row["exposure_time"]
            filter_id = bands[sched_row["filt"]]
            program_id = sched_row["program_id"]
            jd = Time(sched_row["mjd"], format='mjd').jd

            # Walk down from the scheduled field id until one has records.
            cnt = 0
            obstable_field = obstable[obstable["fieldid"] == field - cnt]
            while len(obstable_field) == 0:
                print('No data for field %d...' % (field-cnt))
                cnt = cnt + 1
                if len(obstable) == 0 or field - cnt < min_fieldid:
                    # The original loop never terminated in this case.
                    raise ValueError('No observations for field %d or any lower '
                                     'field id in %s' % (field, opts.observations))
                obstable_field = obstable[obstable["fieldid"] == field - cnt]

            # Median limiting magnitude of each distinct epoch (jd) of that
            # field; the groupby index holds the epoch jd values, sorted.
            limmag_by_jd = obstable_field.groupby('jd')['limMag'].median()
            jds = limmag_by_jd.index.to_numpy()
            limMags = limmag_by_jd.to_numpy()

            if len(jds) == 1:
                jd_chosen = jds[0]
            else:
                # Keep the deepest half of the epochs (largest limMag) and,
                # when more than one remains, pick one at random.
                sortedidx = np.argsort(limMags)[::-1][:len(jds) // 2]
                if len(sortedidx) == 1:
                    jd_chosen = jds[sortedidx[0]]
                else:
                    jd_chosen = jds[sortedidx][np.random.randint(len(sortedidx))]

            # Only this field's records from the chosen epoch.
            obstable_rows = obstable_field[obstable_field["jd"] == jd_chosen]
            for _, row in obstable_rows.iterrows():
                # FIXME: placeholder rescaling of the limiting magnitude by
                # log base 2.5 of the exposure-time ratio; confirm the model.
                nmag = np.log(exptime/row.exptime)/np.log(2.5)

                fid.write('%.7f, %d, %d, %d, %d, %d, %.2f, %.0f, %d\n'%(jd, field, int(row.chid), program_id, int(row.expid), filter_id, row.limMag+nmag, exptime, int(row.status)))

if opts.doEvaluate:

    # Per-field detection efficiency for a toy kilonova population, cached in
    # eff.dat as "<fieldid> <efficiency>" lines and optionally plotted.
    filename = os.path.join(baseoutputDir,'eff.dat')
    if not os.path.isfile(filename):

        nkilonovae = 10000
        mindet = 2                 # epochs needed to count a kilonova as detected
        absmag, dmag = -16, 0.5    # absolute mag at onset, fading in mag per day
        maxdist = 500              # Mpc
        distbins = np.linspace(0,maxdist,1000)
        distprob = distbins ** 2   # weight proportional to d^2 (uniform in volume)
        distprob = distprob / np.sum(distprob)

        distn = rv_discrete(values=(distbins, distprob))
        dists = distn.rvs(size=nkilonovae)

        jd_min = (Time(opts.gps, format='gps') + TimeDelta(Tmin*u.day)).jd
        jd_max = (Time(opts.gps, format='gps') + TimeDelta(Tmax*u.day)).jd

        df = pd.read_csv(obsfile)
        df_group = df.groupby('fieldid')
        with open(filename, 'w') as fid:
            for ii, (fieldid, df_sub) in enumerate(df_group):
                if np.mod(ii, 10) == 0:
                    print('Process field %d/%d' % (ii, len(df_group)))

                # Median limiting magnitude per epoch; index holds the epoch jds.
                limmag_by_jd = df_sub.groupby('jd')['limMag'].median()
                jds = limmag_by_jd.index.to_numpy()
                limmags = limmag_by_jd.to_numpy()

                # Vectorized over kilonovae; the same uniform draws as the
                # former per-kilonova loop (legacy RandomState stream).
                jd_samples = np.random.uniform(jd_min, jd_max, size=nkilonovae)
                # dt[k, e]: days from kilonova k's onset to epoch e.
                dt = jds[np.newaxis, :] - jd_samples[:, np.newaxis]
                mag = absmag + dmag*dt
                mag[dt < 0] = 99.9   # epochs before onset cannot detect it
                # Largest distance (Mpc) at which mag is brighter than limmag.
                dist_threshold = (10**(((limmags-mag)/5.0)+1.0))/1e6
                det = np.count_nonzero(dist_threshold >= dists[:, np.newaxis], axis=1)
                ndet = int(np.count_nonzero(det >= mindet))

                fid.write('%d %.5e\n' % (int(fieldid), ndet/nkilonovae))

    # ndmin=2 keeps a single-field file 2-D; reshape turns an empty one into (0, 2).
    data_out = np.loadtxt(filename, ndmin=2).reshape(-1, 2)
    fields, effs = data_out[:,0], data_out[:,1]

    if opts.doPlots and len(effs) > 0:

        cmapnorm = colors.Normalize(vmin=np.min(effs), vmax=np.max(effs))
        # matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot keeps it.
        cmap = plt.get_cmap('rainbow')

        # Background: Gaia 'srcdens' column reordered by its 'hpx8' pixel index.
        densitymap = os.path.join(opts.inputDir, 'Gaia_hp8_densitymap.fits')
        with fits.open(densitymap) as hdulist:
            hist = hdulist[1].data['srcdens'][np.argsort(hdulist[1].data['hpx8'])]

        fig = plt.figure(figsize=(12, 8))
        gs = fig.add_gridspec(8, 1)
        ax1 = fig.add_subplot(gs[0:6, 0], projection='astro hours mollweide')
        ax2 = fig.add_subplot(gs[7, 0])

        plt.sca(ax1)
        ax1.imshow_hpx(hist, cmap='viridis', norm=LogNorm(), nested=True)
        ax1.grid()

        ax1.coords[0].set_ticks_visible(False)
        ax1.coords[1].set_ticks_visible(False)

        for field, eff in zip(fields, effs):
            field = int(field)
            if not eff > 0: continue
            if field not in moc_structs["ZTF"]: continue

            corners = moc_structs["ZTF"][field]["corners"]
            ras, decs = corners[:,0], corners[:,1]
            if len(ras) == 4:
                # Visit corners 0,1,3,2 and repeat the first to close the outline.
                order = [0, 1, 3, 2, 0]
                ras, decs = np.asarray(ras)[order], np.asarray(decs)[order]

            # Draw transparent when corners lie on both sides of RA=180 deg,
            # unless they also span RA>300 and RA<60.
            straddles_180 = np.any(ras >= 180.0) and np.any(ras < 180.0)
            spans_zero = np.any(ras > 300.0) and np.any(ras < 60.0)
            alpha = 0.0 if (straddles_180 and not spans_zero) else 0.5
            color = cmap(cmapnorm(eff))
            linewidth = 3.0

            poly = patches.Polygon(np.vstack((ras, decs)).T, transform=ax1.get_transform('world'), alpha=alpha, color=color)
            ax1.plot(ras, decs, color=color, transform=ax1.get_transform('world'), alpha=alpha, linewidth=linewidth)
            ax1.add_patch(poly)

        cbar = matplotlib.colorbar.ColorbarBase(ax2, cmap=cmap,
                                                norm=cmapnorm,
                                                orientation='horizontal')
        cbar.set_label('Efficiency')

        plotName = os.path.join(baseoutputDir,'eff.pdf')
        plt.savefig(plotName,dpi=200,bbox_inches='tight')
        plt.close()

