#!python
import matplotlib.pyplot as plt 
import os,sys 
import matplotlib.gridspec as gridspec
import argparse 
import numpy as np
import warnings
warnings.filterwarnings("ignore")

import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.time import Time
from astroplan import FixedTarget, Observer, EclipsingSystem
from astroplan import (PrimaryEclipseConstraint, is_event_observable, AtNightConstraint, AltitudeConstraint, LocalTimeConstraint)
from astroplan.plots import dark_style_sheet, plot_airmass, plot_sky
from astroplan import moon, Constraint , min_best_rescale



class HAconstraint(Constraint):
    """Constrain the hour angle of a target to lie between two limits.

    Parameters
    ----------
    min : float
        Lower hour-angle limit (exclusive). Compared directly against the
        wrapped hour angle, so it is expected in hours.
    max : float
        Upper hour-angle limit (exclusive), in the same units as ``min``.
    boolean_constraint : bool
        If True, ``compute_constraint`` returns a boolean mask that is True
        where the hour angle lies strictly between ``min`` and ``max``.
        If False, it returns the score produced by ``min_best_rescale``.
    """

    def __init__ (self, min = -12, max = 12, boolean_constraint=True):
        self.min = min
        self.max = max
        self.boolean_constraint = boolean_constraint

    def compute_constraint(self, times, observer, targets):
        """Evaluate the hour-angle constraint for every time/target pair.

        Parameters
        ----------
        times
            Times passed on to ``observer.target_hour_angle``.
        observer
            Object providing ``target_hour_angle(times, target)``.
        targets
            Coordinates of the target(s); wrapped in a ``FixedTarget``.

        Returns
        -------
        Boolean mask (True where the constraint is satisfied) when
        ``boolean_constraint`` is set, otherwise the rescaled score. The
        result keeps the shape of the hour-angle array.
        """
        HA = observer.target_hour_angle(times, FixedTarget(targets)).value

        # Map hour angles above 12 onto negative values so that they can be
        # compared with limits that are symmetric about the meridian.
        # NOTE(review): assumes the value is in hours within [0, 24) - confirm.
        # Done element-wise so every time is tested, not just the first one.
        HA = np.where(HA > 12, HA - 24., HA)

        if self.boolean_constraint:
            # True means "constraint satisfied", i.e. inside the limits.
            return (HA > self.min) & (HA < self.max)

        # Non-boolean score: rescale the hour angles onto the interval
        # between zero and one, scoring values below ``min`` as 0.
        return min_best_rescale(HA, self.min, self.max, less_than_min=0)

description = '''Transit prediction'''

# Command-line interface. The parser is built at import time; the arguments
# are only parsed when the file is run as a script.
parser = argparse.ArgumentParser('predict', description=description)

# --- Ephemeris -------------------------------------------------------------
# t_zero is interpreted as a Julian Date and the period as days further down
# in the script, so the help text states those units explicitly.
parser.add_argument('-a',
                    '--t_zero',
                    help='The transit epoch (time of mid-transit) as a Julian Date.',
                    default=2457500.60872, type=float)

parser.add_argument('-b',
                    '--period',
                    help='The orbital period in days.',
                    default=4.70744, type=float)

parser.add_argument('-c',
                    '--width',
                    help='The transit width in hrs.',
                    default=5.8, type=float)

parser.add_argument('-d',
                    '--ntransits',
                    help='The number of transits to predict.',
                    default=10, type=int)

# --- Target position ---------------------------------------------------------
parser.add_argument('-e',
                    '--ra',
                    help='The RA in deg.',
                    default=205.96079385700, type=float)

parser.add_argument('-f',
                    '--dec',
                    help='The Dec in deg.',
                    default=-12.35998970810, type=float)

# --- Behaviour switches ------------------------------------------------------
parser.add_argument('--plot', action="store_true", default=False, help="Plot each night")
parser.add_argument('--complete', action="store_true", default=False, help="Only complete nights")
parser.add_argument('--LCO', action="store_true", default=False, help="For longitudinal observatories like LCO (or space?). Doesn't check for observability of each object.")

# --- Site, target label and start date ---------------------------------------
parser.add_argument('-g',
                '--observatory',
                help='The Observatory.',
                default='SAAO')

parser.add_argument('-j',
                '--name',
                help='The target name.',
                default='Star 1')

parser.add_argument('-k',
                '--date',
                help='The date from which to calculate. If not supplied, will default to today. Should be supplied as "2017-01-01"',
                default='now')

if __name__=='__main__':
    # Script entry point: find the next observable transits of the target,
    # print two summary tables (epoch timings with airmass, then target
    # rise/set, hour angle and Moon information) and optionally save one
    # airmass plot per epoch.

    # Parse the arguments
    args = parser.parse_args()

    # Define the observatory
    observatory = Observer.at_site(args.observatory)

    # Define the coordinates
    skycoord = SkyCoord(args.ra*u.deg, args.dec*u.deg, frame='icrs')
    fixedtarget = FixedTarget(skycoord, name = args.name)

    primary_eclipse_time = Time(args.t_zero, format='jd')
    orbital_period = args.period * u.day
    eclipse_duration = args.width * u.hour

    # Define the eclipsing target
    eclipsetarget = EclipsingSystem(primary_eclipse_time=primary_eclipse_time,
                                    orbital_period=orbital_period, duration=eclipse_duration,
                                    name=args.name)

    # Work out the time from which to search
    if args.date == 'now':
        obs_time = Time.now()
    else:
        obs_time = Time(args.date + ' 12:00')

    # Containers filled by the search loop below. They start empty so that
    # the rest of the script still runs when no transits are requested.
    midtransit_times = []
    ingressegress = []
    mask_observervable, mask_observervable_entirely = [], []

    # Now do constraints
    constraints = [AtNightConstraint.twilight_civil(), AltitudeConstraint(min=30*u.deg)]

    #if args.observatory == 'SAAO' : constraints.append(HAconstraint(min=-4, max = 5))

    # Now calculate transit times and check observability. Each pass asks for
    # twice as many upcoming epochs as the previous one until enough of them
    # survive the observability cut.
    max_multiplier = 1024   # stop doubling here instead of searching forever
    i_mult = 1
    while len(midtransit_times) < args.ntransits:
        if i_mult > max_multiplier:
            # A target that never satisfies the constraints would otherwise
            # keep doubling the number of epochs without ever terminating.
            sys.exit('Could not find {:} observable transits of {:} within the next {:} epochs.'.format(
                     args.ntransits, args.name, (i_mult//2)*args.ntransits))

        n_candidates = i_mult*args.ntransits

        # First get ingress and egress time, along with times of mid-transit
        ingressegress = eclipsetarget.next_primary_ingress_egress_time(obs_time, n_eclipses=n_candidates)
        midtransit_times = np.array(eclipsetarget.next_primary_eclipse_time(obs_time, n_eclipses=n_candidates))

        # Get mask to see if each epoch is observable
        if args.LCO:
            # No observability check: every epoch counts as observable.
            # (Built-in bool is used because np.bool no longer exists in
            # current NumPy releases.)
            mask_observervable = np.ones(len(midtransit_times), dtype=bool)
            mask_observervable_entirely = np.ones(len(midtransit_times), dtype=bool)
        else:
            mask_observervable = is_event_observable(constraints, observatory, fixedtarget, times=midtransit_times)[0]
            mask_observervable_entirely = is_event_observable(constraints, observatory, fixedtarget, times_ingress_egress=ingressegress)[0]

        # Finally, mask the epochs to make sure we have enough. With
        # --complete only fully observable transits are kept, otherwise
        # those observable at mid-transit.
        keep = mask_observervable_entirely if args.complete else mask_observervable
        midtransit_times = midtransit_times[keep]
        ingressegress = ingressegress[keep]
        mask_observervable_entirely = mask_observervable_entirely[keep]

        # Ask for twice as many epochs next time if we don't have enough
        i_mult = i_mult*2

    # Calculate sunset and sunrise times, only for the epochs that are reported
    reported_times = midtransit_times[:args.ntransits]
    sun_set_times = [observatory.sun_set_time(t, which="previous") for t in reported_times]
    sun_rise_times = [observatory.sun_rise_time(t, which="next") for t in reported_times]

    print('-----------------------------------------------------------------------------------------------------------------------------------------------------')
    print('| Summary of Epochs for {:<15}                                                                                                             |'.format(args.name))
    print('| All times in UTC with airmass given in square brackets                                                                                            |')
    if not args.LCO : print('| Observatory : {:<10}                                                                                                                          |'.format(args.observatory))
    print('|----------------------------------------------------------------------------------------------------------------------------------------------------')
    print('|{:>5} |    {:}    |     {:}        |        {:}         |        {:}         |        {:}         |      {:}       |    {:<6}    |'.format('Epoch', 'date', 'sunset', 'in', 'mid', 'out', 'sunrise', 'Complete transit'))
    print('|------|------------|-------------------|-------------------|--------------------|--------------------|--------------------|------------------------|')
    for i in range(args.ntransits):
        ingress_time = Time(ingressegress[i][0], format='jd')
        egress_time = Time(ingressegress[i][1], format='jd')

        mid_datetime = midtransit_times[i].datetime
        set_datetime = sun_set_times[i].datetime
        rise_datetime = sun_rise_times[i].datetime
        in_datetime = ingress_time.datetime
        out_datetime = egress_time.datetime

        # Airmass (sec z) at sunset, ingress, mid-transit, egress and sunrise
        airmasses = [observatory.altaz(when, skycoord).secz
                     for when in (sun_set_times[i], ingress_time, midtransit_times[i], egress_time, sun_rise_times[i])]

        print('| {:>3}  | {:>4} {:0>2} {:0>2} | {:0>2}:{:0>2}:{:0>2.0f} [{:>6.2f}] | {:0>2}:{:0>2}:{:0>2.0f} [{:>6.2f}] | {:0>2}:{:0>2}:{:0>2.0f} [{:>6.2f}]  | {:0>2}:{:0>2}:{:0>2.0f} [{:>6.2f}]  | {:0>2}:{:0>2}:{:0>2.0f} [{:>6.2f}]  |           {:}         |'.format(i+1,
                                 mid_datetime.year, mid_datetime.month, mid_datetime.day,
                                 set_datetime.hour, set_datetime.minute, set_datetime.second, airmasses[0],
                                 in_datetime.hour, in_datetime.minute, in_datetime.second, airmasses[1],
                                 mid_datetime.hour, mid_datetime.minute, mid_datetime.second, airmasses[2],
                                 out_datetime.hour, out_datetime.minute, out_datetime.second, airmasses[3],
                                 rise_datetime.hour, rise_datetime.minute, rise_datetime.second, airmasses[4],
                                 mask_observervable_entirely[i]
                                 ))
    print('|---------------------------------------------------------------------------------------------------------------------------------------------------|')
    print('| Target set and rise times for each epoch                                                                                                          |')
    print('|---------------------------------------------------------------------------------------------------------------------------------------------------|')
    print('|{:} |   {:}   |    {:}   |  {:}   | {:}                  |                                     |'.format('Epoch', 'rise', 'set', 'Hour angle information from Sun set    ', 'Moon information'))
    print('|---------------------------------------------------------------------------------------------------------------------------------------------------|')
    print('|                            |  +0hr  |  +2hr  |  +4hr  |  +6hr  |  +8hr  | Angle          | Illumination     |                                     |')
    print('|---------------------------------------------------------------------------------------------------------------------------------------------------|')

    for i in range(args.ntransits):

        # Target rise before and set after mid-transit; computed once and
        # reused for the printed times and for the Moon information.
        target_rise = observatory.target_rise_time(midtransit_times[i], fixedtarget, which='previous')
        target_set = observatory.target_set_time(midtransit_times[i], fixedtarget, which='next')
        rise_time = target_rise.datetime
        set_time = target_set.datetime

        print('| {:>3}  | {:0>2}:{:0>2}:{:0>2.0f} | {:0>2}:{:0>2}:{:0>2.0f} '.format(i+1,
                rise_time.hour, rise_time.minute, rise_time.second,
                set_time.hour, set_time.minute, set_time.second), end = "")

        # Hour angle of the target at 0, 2, 4, 6 and 8 hours after sunset
        for j in np.arange(0, 10, 2):
            new_time = sun_set_times[i] + j*u.hour
            HA = observatory.target_hour_angle(new_time, fixedtarget).value
            if HA > 12 : HA -= 24.
            print('|  {:0<6.2f} '.format(HA), end='')

        # Moon-target separation at target rise and set. The separation is
        # taken from the Moon's coordinate so that the target is transformed
        # into the Moon's frame; calling it the other way round would first
        # transform the Moon into the target's ICRS frame and generally
        # gives a different angle.
        Moon_angle_rise = moon.get_moon(target_rise).separation(skycoord).value
        Moon_angle_set  = moon.get_moon(target_set).separation(skycoord).value

        # Moon illumination at the same two instants, as a percentage
        Moon_illumen_rise = observatory.moon_illumination(target_rise)*100.
        Moon_illumen_set  = observatory.moon_illumination(target_set)*100.
        print('| {:>5.1f} -> {:>5.1f} | {:>6.2f} -> {:>6.2f} |                                     |'.format(Moon_angle_rise, Moon_angle_set, Moon_illumen_rise, Moon_illumen_set))

    print('-----------------------------------------------------------------------------------------------------------------------------------------------------')

    # Plotting commands, if needed: one airmass plot per reported epoch,
    # saved as "<name>_Epoch_<n>.png".
    if args.plot:
        half_width = args.width*u.hour/2

        for i in range(args.ntransits):

            plt.figure()
            ax1 = plt.gca()

            # Plot airmass
            plot_airmass(fixedtarget, observatory, midtransit_times[i], style_sheet=dark_style_sheet, ax=ax1, brightness_shading=True, altitude_yaxis=True)

            # Shade the transit: half a width either side of mid-transit
            transit_span = [(midtransit_times[i] - half_width).datetime,
                            (midtransit_times[i] + half_width).datetime]
            ax1.fill_between(transit_span, [3,3], color='blue', alpha=0.2,hatch="X", edgecolor="b")

            ax1.set_title('Gray is nightitme\nBlue hatch is in transit')

            plt.savefig('{:}_Epoch_{:}.png'.format(args.name, i+1))
            plt.close()