#!/home/sam/anaconda3/bin/python
import sys
import logging
import matplotlib.pyplot as plt
import numpy as np
import glob
from astropy.table import Table, Column
import matplotlib.gridspec as gridspec
from scipy.stats import sem 
from astropy.time import Time
import warnings
from scipy.signal import find_peaks
warnings.simplefilter(action='ignore', category=FutureWarning)
# `np.warnings` was merely an alias of the stdlib `warnings` module and was
# removed in NumPy 1.24; call the stdlib module directly (identical effect).
# NOTE(review): this silences *all* warnings process-wide, which also hides
# e.g. polyfit RankWarnings -- consider narrowing the filter.
warnings.filterwarnings('ignore')
plt.rcParams.update({'font.size': 7})





def round_of_rating(number):
    """Return ``number`` snapped to the nearest multiple of one half.

    The value is doubled, rounded with Python's built-in ``round`` (so exact
    ties such as 1.25 resolve half-to-even on the doubled value), and halved.

    >>> round_of_rating(1.3)
    1.5
    >>> round_of_rating(2.6)
    2.5
    >>> round_of_rating(3.0)
    3.0
    >>> round_of_rating(4.1)
    4.0"""
    half_steps = round(2 * number)
    return half_steps / 2


def lc_bin(time, flux, bin_width):
    '''
    Bin a light curve into contiguous, fixed-width bins in time.

    Bins start at min(time) and are extended until the last bin contains
    max(time), so no sample is discarded (the previous edge construction,
    np.arange(min, max, w), dropped every sample past the last edge and
    produced no bins at all when the span was shorter than bin_width).

    Parameters
    ----------
    time : array-like
        Sample times. Must use the same units as ``bin_width``.
    flux : array-like
        Sample values, same length as ``time``.
    bin_width : float
        Width of each bin (same units as ``time``); must be positive.

    Returns
    -------
    time_bin, flux_bin, err_bin : ndarray
        Bin centres, mean flux per bin, and the standard error of the mean
        (scipy.stats.sem, ddof=1). Bins whose error is NaN -- empty bins,
        single-sample bins, or bins containing NaN flux -- are removed.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive.
    '''
    if not bin_width > 0:
        raise ValueError('bin_width must be positive, got {:}'.format(bin_width))

    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    # Enough bins that max(time) lies strictly inside the final bin.
    t_min = np.min(time)
    n_bins = int(np.floor((np.max(time) - t_min) / bin_width)) + 1
    edges = t_min + bin_width * np.arange(n_bins + 1)
    dig = np.digitize(time, edges)  # bin i (1-based) holds edges[i-1] <= t < edges[i]
    centres = (edges[1:] + edges[:-1]) / 2

    flux_binned = np.full(n_bins, np.nan)
    err_binned = np.full(n_bins, np.nan)
    for i in range(1, n_bins + 1):
        members = flux[dig == i]  # computed once per bin
        if members.size:
            flux_binned[i - 1] = members.mean()
            err_binned[i - 1] = sem(members)

    keep = ~np.isnan(err_binned)
    return centres[keep], flux_binned[keep], err_binned[keep]



if __name__=="__main__":
    # First, parse the action ID
    if len(sys.argv) < 2:
        sys.exit('usage: {:} ACTION_ID'.format(sys.argv[0]))
    action = sys.argv[1]

    print('Processing action : ', action)

    ###########################################################
    # Section 1
    # Before selecting an aperture, we need to look for bad
    # comparison stars. To do this, we look at the 3.5-pixel
    # aperture only.
    ############################################################
    aperture_to_use = '3.5'

    # Load the photometry. Column 0 (the filename) is dropped because
    # genfromtxt reads that string column as NaN.
    phot_table = np.genfromtxt('{:}.phot{:}'.format(action, aperture_to_use))[:, 1:]
    # After the 8 admin columns, each star occupies 7 columns
    # (X Y FLUX FLUXERR SKY SKYERR MAXPIX).
    number_of_coparison_stars = (phot_table.shape[1] - (7 + 7)) // 7

    # Per-star rejection flag: 0 = good, 1 = rejected. The value 2 was used by a
    # saturation check (70th percentile of MAXPIX > 30000) that is disabled.
    bad_star_mask = np.zeros(number_of_coparison_stars)

    # Set when so many stars are rejected that the night itself is suspect
    night_flag = False

    # squeeze=False keeps `ax` indexable even with a single star
    fig1, ax = plt.subplots(nrows=number_of_coparison_stars, ncols=1, figsize=(5, 50), squeeze=False)
    ax = ax[:, 0]

    # Every star's flux is divided by column -5 of the table.
    # NOTE(review): presumably the FLUX column of the last star -- confirm.
    reference_flux = phot_table[:, -5]

    for j in range(number_of_coparison_stars):
        # NOTE(review): per the column layout in Section 2 the first 7-column
        # group (starting at index 8) is the target, so j == 0 may be the
        # target rather than a comparison -- confirm against the file format.
        x_col = 8 + j*7
        star_flux = phot_table[:, x_col + 2]

        if np.min(star_flux) < 0: bad_star_mask[j] = 1  # negative flux: faint / broken photometry
        if np.sum(phot_table[:, x_col] < 10) > 5: bad_star_mask[j] = 1  # X < 10 in >5 frames (presumably near the chip edge)
        if np.sum(phot_table[:, x_col + 1] < 10) > 5: bad_star_mask[j] = 1  # same test for Y

        normalised = star_flux / reference_flux
        ax[j].scatter(phot_table[:, 0], normalised, alpha=0.1, c='k', s=5)
        # Overlay 15-minute bins (0.25/24 day)
        x_, y_, e_ = lc_bin(phot_table[:, 0], normalised, 0.25/24)
        ax[j].scatter(x_, y_, alpha=1, c='r', s=10)

        ax[j].set_ylim(*np.percentile(normalised, [1, 99]))
        ax[j].set_xticks([]); ax[j].set_yticks([])
        ax[j].set_ylabel(str(j+1), rotation=0)

    # If most stars look bad the problem is the night, not the stars:
    # keep every star and flag the night instead.
    if np.sum(bad_star_mask) > int(0.75*number_of_coparison_stars):
        bad_star_mask = np.zeros(number_of_coparison_stars)
        night_flag = True
        print('\tMore than 75% of stars failed the checks; flagging the night and keeping all stars')

    # Highlight rejected stars in the summary figure
    for j in range(number_of_coparison_stars):
        if bad_star_mask[j] == 1:
            ax[j].set_facecolor('xkcd:salmon')
        elif bad_star_mask[j] == 2:
            ax[j].set_facecolor('xkcd:crimson')

    plt.tight_layout()
    plt.savefig('{:}_comparison_star_summary.png'.format(action))
    plt.close(fig1)

    # Boolean mask of the stars to KEEP (True = good comparison)
    good_star_mask = bad_star_mask == 0
    print('\tA total of {:} unique stars have been rejected'.format(np.sum(~good_star_mask)))

    ###########################################################
    # Section 2
    # Now we have the bad comparison stars, we can go through
    # each aperture, sum the flux of the good comparisons, and
    # estimate the best aperture.
    ############################################################
    print('\nFinding the  best aperture')

    # The aperture radius is the file suffix, e.g. "<action>.phot3.5" -> "3.5".
    # glob.escape guards against glob metacharacters in the action ID.
    prefix = '{:}.phot'.format(action)
    apertures = []
    for path in glob.glob(glob.escape(prefix) + '*'):
        suffix = path[len(prefix):]
        try:
            float(suffix)
        except ValueError:
            continue  # not a photometry table (e.g. a backup file matching the glob)
        apertures.append(suffix)
    apertures.sort(key=float)
    if not apertures:
        sys.exit('No photometry files matching {:}* were found'.format(prefix))
    print('\t{:} apertures found'.format(len(apertures)))
    print('\t{:}'.format(apertures))

    # Left column: one light curve per aperture; right column: RMS vs aperture
    fig1 = plt.figure(constrained_layout=False, figsize=(5,20))
    gs1 = gridspec.GridSpec(len(apertures), 2, figure=fig1)
    rms = np.empty(len(apertures))

    # Column layout after dropping the filename column:
    # 8 admin columns : JD-MID BJD-TDB-MID HJD-MID EXPTIME AIRMASS FWHM AGERRX AGERRY
    # 7 for the target: X Y FLUX FLUXERR SKY SKYERR MAXPIX
    # 7 for each comp : X Y FLUX FLUXERR SKY SKYERR MAXPIX
    N_cloud_thresh = 0.1  # max standard error allowed in a binned point before masking (cloud)
    pad = 10./24./60      # 10 minutes in days, added either side of the noisy span

    for j, aperture in enumerate(apertures):
        phot_table = np.genfromtxt('{:}.phot{:}'.format(action, aperture))[:, 1:]

        # Time relative to the integer part of the first JD
        t = phot_table[:, 0] - int(np.min(phot_table[:, 0]))

        # Sum the good stars' FLUX columns (excluding the last FLUX column),
        # normalised by column -5 as in Section 1.
        flux_columns = phot_table[:, (8+2)::7][:, :-1]
        comparison_flux = flux_columns[:, good_star_mask].sum(axis=1) / phot_table[:, -5]

        # 10-sigma clip around the median (the old test compared each value
        # with itself +/- 10 sigma and therefore never masked anything)
        median = np.median(comparison_flux)
        std = np.std(comparison_flux)
        sigma_clip_mask = (comparison_flux < median - 10*std) | (comparison_flux > median + 10*std)

        # Bin into 15-minute bins (0.25/24 day) and mask the whole span of bins
        # whose standard error exceeds the threshold -- likely cloud.
        time_bin, mag_bin, mag_bin_err = lc_bin(t, comparison_flux, 0.25/24)
        if time_bin.shape[0] > 4:  # in case the time axis is too short
            noisy = mag_bin_err > N_cloud_thresh
            if np.any(noisy):
                noisy_data_start = np.min(time_bin[noisy]) - pad
                noisy_data_end = np.max(time_bin[noisy]) + pad
                sigma_clip_mask |= (t > noisy_data_start) & (t < noisy_data_end)

        # If everything is masked the mask is useless: fall back to all data
        if np.all(sigma_clip_mask):
            sigma_clip_mask = np.zeros_like(sigma_clip_mask)

        keep = ~sigma_clip_mask
        t_fit = t[keep]
        f_fit = comparison_flux[keep]

        # Normalised inverse-square weights; uniform scaling of w does not change
        # the polyfit solution, so this only fixes the old "*" vs "/" typo.
        weights = f_fit**-2 / np.sum(f_fit**-2)
        model = np.poly1d(np.polyfit(t_fit, f_fit, 2, w=weights))(t_fit)

        # RMS of the residuals about the quadratic trend
        rms[j] = np.std(f_fit / model)

        ax = fig1.add_subplot(gs1[j, 0])
        ax.scatter(t_fit, f_fit, alpha=1, c='k', s=5)
        ax.plot(t_fit, model, 'b--')
        ax.set_xticks([]); ax.set_yticks([])
        ax.set_ylabel('{:} pix'.format(aperture), fontsize=7)
        ax2 = ax.twinx()
        ax2.set_ylabel('{:,} ppm'.format(int(1e6*rms[j])))

        # Report the fraction of points excluded by THIS aperture's mask
        excluded = 100*(np.sum(sigma_clip_mask)/sigma_clip_mask.shape[0])
        print('\t\t{:>8} {:>5} has an rms of {:>8,} ppm [{:>3.1f}% exclusion]'.format('Aperture', aperture, int(1e6*rms[j]), excluded))

    # RMS (ppm) vs aperture in the right-hand column, small apertures at the top.
    # Apertures are plotted as numbers (not category strings) so the axis
    # limits and the best-aperture marker land at the right radius.
    aperture_values = np.array(apertures, dtype=float)
    ax = fig1.add_subplot(gs1[:, 1])
    ax.plot(1e6*rms, aperture_values, 'b')
    ax.set_ylim(max(16.0, np.max(aperture_values)), 0)
    ax.set_xlabel('RMS [ppm]')
    ax.set_ylabel('Aperture size [pixel]')

    # Snap the lowest-RMS aperture to the nearest half pixel and pick the
    # matching measured aperture string; if that radius was not measured,
    # fall back to the lowest-RMS aperture itself instead of writing ''.
    best_index = int(np.argmin(rms))
    best_aperture = round_of_rating(aperture_values[best_index])
    matches = [a for a, v in zip(apertures, aperture_values) if v == best_aperture]
    best_aperture_string = matches[-1] if matches else apertures[best_index]

    print('\t\tBest aperture for action: {:} pixels'.format(best_aperture_string))
    ax.axhline(float(best_aperture_string), c='b', ls='--')
    fig1.suptitle('Action {:}'.format(action))
    fig1.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig1.savefig('{:}_aperture_rms.png'.format(action))
    plt.close(fig1)

    # Hand the chosen aperture to downstream steps
    with open('.monoaperture', 'w') as f:
        f.write('{:}'.format(best_aperture_string))