#!/home/sam/anaconda3/bin/python

"""
STARTING OVER WITH A COMPLETE REWRITE

PROCESS:
    Take in which object to analyse
    Go find its folder and phot output

    For each action
        For each aperture size
            For each comparison star
                work out the RMS/quality
                flag any potentially bad stars

            Make the light curve
            Analyse the RMS (OOT if transit) in data
        Make a final light curve from the best aperture and comparisons
"""
import sys
import logging
import matplotlib.pyplot as plt
import numpy as np
import glob
from astropy.table import Table, Column
import matplotlib.gridspec as gridspec
from scipy.stats import sem 
from astropy.time import Time
import warnings
from scipy.signal import find_peaks
warnings.simplefilter(action='ignore', category=FutureWarning)
# `np.warnings` was just a re-export of the stdlib module and no longer exists in
# NumPy >= 1.24, so silence all warnings through `warnings` itself.
warnings.filterwarnings('ignore')

# Small default font so the dense multi-panel diagnostic figures stay legible.
plt.rcParams.update({'font.size': 7})


def round_of_rating(number):
    """Return ``number`` rounded to the nearest multiple of 0.5.

    The value is doubled, rounded with the built-in ``round`` and halved again,
    so exact ties follow Python's round-half-to-even rule on the doubled value
    (1.25 -> 1.0, 1.75 -> 2.0).

    >>> round_of_rating(1.3)
    1.5
    >>> round_of_rating(2.6)
    2.5
    >>> round_of_rating(3.0)
    3.0
    >>> round_of_rating(4.1)
    4.0
    """
    doubled = 2 * number
    return round(doubled) / 2.0


def lc_bin(time, flux, bin_width):
    '''
    Bin a light curve into consecutive, equal-width time bins.

    Bins start at ``min(time)`` and are ``bin_width`` wide; ``time`` and
    ``bin_width`` must have the same units. Every finite time stamp is assigned
    to a bin, including those in the final (possibly partial) bin at the end of
    the series. Bins with fewer than two samples, or whose mean/standard error is
    NaN, are dropped from the output.

    Parameters
    ----------
    time : array_like
        Sample times. Non-finite times are ignored.
    flux : array_like
        Value measured at each time; same length as ``time``.
    bin_width : float
        Bin width in the units of ``time``. Must be positive.

    Returns
    -------
    time_bin : ndarray
        Centre of each retained bin.
    flux_bin : ndarray
        Mean of ``flux`` in each retained bin.
    err_bin : ndarray
        Standard error of the mean (``scipy.stats.sem``) in each retained bin.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive.
    '''
    if not bin_width > 0:
        raise ValueError('bin_width must be positive, got {!r}'.format(bin_width))

    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)

    finite = np.isfinite(time)
    time = time[finite]
    flux = flux[finite]
    if time.size == 0:
        empty = np.array([], dtype=float)
        return empty, empty.copy(), empty.copy()

    t0 = np.min(time)
    # Bin index of every sample. Computing it directly (instead of digitising
    # against np.arange(min, max, width), whose last edge stops short of
    # max(time)) keeps the samples that fall in the final, partial bin.
    bin_index = np.floor((time - t0) / bin_width).astype(int)
    n_bins = int(bin_index.max()) + 1

    centres = t0 + (np.arange(n_bins) + 0.5) * bin_width
    flux_mean = np.full(n_bins, np.nan)
    flux_err = np.full(n_bins, np.nan)
    for k in np.unique(bin_index):
        in_bin = flux[bin_index == k]
        if in_bin.size < 2:
            continue  # the standard error is undefined for a single sample
        flux_mean[k] = in_bin.mean()
        flux_err[k] = sem(in_bin)

    keep = ~np.isnan(flux_err)
    return centres[keep], flux_mean[keep], flux_err[keep]



# TODO Add method of rejecting frames rather than stars if many stars have failed phot
# TODO Add docstrings
# pylint: disable = invalid-name
# pylint: disable = redefined-outer-name
# pylint: disable = no-member
# pylint: disable = too-many-locals
# pylint: disable = too-many-arguments
# pylint: disable = unused-variable
# pylint: disable = line-too-long

# Default legend font size for every figure produced by this script.
plt.rc('legend', fontsize=10)

if __name__ == "__main__":
    # First, let's get the actions for this object. Every '<action>.phot<aperture>'
    # file in the working directory belongs to the action named before the first '.'.
    actions = np.unique([fname.split('.')[0] for fname in glob.glob('*.phot*')])
    number_per_action = [len(glob.glob('{:}.phot*'.format(action))) for action in actions]

    if actions.shape[0] == 0:
        # Print the message on stderr and exit with status 1.
        sys.exit('No actions to reduce.')

    print('{:} actions found:'.format(actions.shape[0]))
    for i in range(actions.shape[0]):
        print('\t{:} with {:>3} apertures'.format(actions[i], number_per_action[i]))




    print('\n\nSearching for inhomegenous reference stars where the field has been reset')
    aperture_to_use = '3.5'

    def _load_phot_table(action):
        """Load `action`'s photometry in the `aperture_to_use` aperture.

        The first column (image filename) is dropped because genfromtxt reads it as NaN.
        """
        return np.genfromtxt('{:}.phot{:}'.format(action, aperture_to_use))[:, 1:]

    def _comparison_xy(table, row):
        """Return the first 10 X (columns 8::7) and Y (columns 9::7) values of frame `row`.

        `row` is clamped to the last frame so short actions do not raise IndexError.
        """
        row = min(row, table.shape[0] - 1)
        return table[row, 8::7][:10], table[row, 9::7][:10]

    f = plt.figure()
    # Reference positions start from frame 50 of the first action.
    X_ref, Y_ref = _comparison_xy(_load_phot_table(actions[0]), 50)
    diff_ = [0]
    diff_ref_regions = []  # inclusive [first, last] action index of each region
    idx_ref = 0
    diff_thresh = 70000

    i = 1
    while i < len(actions):
        phot_table = _load_phot_table(actions[i])
        X_start, Y_start = _comparison_xy(phot_table, 0)
        # Total absolute shift of the star positions between the reference frame
        # and the first frame of this action.
        diff = np.sum(np.abs(X_start - X_ref)) + np.sum(np.abs(Y_start - Y_ref))
        diff_.append(diff)
        X_ref, Y_ref = _comparison_xy(phot_table, 50)
        if (diff > diff_thresh) or (abs(number_per_action[i] - number_per_action[i-1]) > 1):
            diff_ref_regions.append([idx_ref, i - 1])
            idx_ref = i
            i += 1  # skip the following action, which is expected to trigger as well
            # NOTE(review): after the skip, i-2 is the action *before* the jump, so the
            # reference is reset to the old field - confirm this is intended.
            X_ref, Y_ref = _comparison_xy(_load_phot_table(actions[i - 2]), 50)
        i += 1

        # Close the last region once the index runs off the end. The skip above can
        # move `i` past len(actions), so test with >= rather than ==.
        if i >= len(actions):
            diff_ref_regions.append([idx_ref, len(actions) - 1])

    if len(diff_ref_regions) == 0:
        diff_ref_regions = [[0, len(actions)]]
    print('\tFound {:} regions with different refernece stars... '.format(len(diff_ref_regions)))
    if len(diff_ref_regions) > 1:
        for start, end in diff_ref_regions:
            print('\t\tBetween actions ', actions[start], ' and ', actions[end])
    else:
        print('All appear to be the same')

    plt.plot(range(len(diff_)), diff_)
    plt.axhline(np.median(diff_))
    if len(diff_ref_regions) > 1:
        for start, end in diff_ref_regions:
            plt.axvline(end, ls='--')

    plt.savefig('ref_XY_diff.png')
    plt.close()


    print('\nSearching for bad comparison stars in the {:} pixel aperture for {:} reference images'.format(aperture_to_use, len(diff_ref_regions)))

    ###########################################################
    # Section 1
    # Before selecting an aperture, we need to look for bad
    # comparison stars. To do this, we will look at the 3.5
    # pixel aperture only.
    ############################################################

    # Only the first 10 comparison stars are examined.
    number_of_coparison_stars = 10

    # bad_star_mask[region, action, star] = 1 when the star is flagged in that action.
    bad_star_mask = np.zeros((len(diff_ref_regions), actions.shape[0], number_of_coparison_stars))
    diff_ref_regions_idx = None

    # bad_night[action] is set when every comparison star was flagged
    # (plain `bool`: the `np.bool` alias no longer exists in NumPy >= 1.24).
    bad_night = np.zeros(actions.shape[0], dtype=bool)

    for i in range(actions.shape[0]):
        # Region containing this action. Region bounds are inclusive; if no region
        # matches, keep the previous one.
        diff_ref_regions_idx = next(
            (r for r, (start, end) in enumerate(diff_ref_regions) if start <= i <= end),
            diff_ref_regions_idx)

        fig1, ax = plt.subplots(nrows=number_of_coparison_stars, ncols=1, figsize=(5, 50))

        # Drop the filename column (genfromtxt reads it as NaN).
        phot_table = np.genfromtxt('{:}.phot{:}'.format(actions[i], aperture_to_use))[:, 1:]
        jd = phot_table[:, 0]
        target_flux = phot_table[:, -5]

        for j in range(number_of_coparison_stars):
            x_col = 8 + j*7  # X column of star j; Y and FLUX follow it
            comp_flux = phot_table[:, x_col + 2]
            median = np.median(comp_flux)
            std = np.std(comp_flux)

            # Flag the star if its X, Y or FLUX is below 10 in more than one frame.
            if any(np.sum(phot_table[:, x_col + k] < 10) > 1 for k in (0, 1, 2)):
                bad_star_mask[diff_ref_regions_idx, i, j] = 1

            # 5-sigma outlier mask of the star's flux (not applied to the plot below).
            mask = (comp_flux < median - 5*std) | (comp_flux > median + 5*std)

            ratio = comp_flux / target_flux
            ax[j].scatter(jd, ratio, alpha=0.1, c='k', s=5)
            x_, y_, e_ = lc_bin(jd, ratio, 0.25/24)
            ax[j].scatter(x_, y_, alpha=1, c='r', s=10)

            ax[j].set_ylim(*np.percentile(ratio, [1, 99]))
            ax[j].set_xticks([]); ax[j].set_yticks([])
            ax[j].set_ylabel(str(j+1), rotation=0)

            if bad_star_mask[diff_ref_regions_idx, i, j] == 1:
                ax[j].set_facecolor((1.0, 0.47, 0.42))  # salmon background: flagged star

        # If every star failed, treat this as a bad night instead. Un-flag its stars
        # so this one night does not reject them for the whole region.
        if np.sum(bad_star_mask[diff_ref_regions_idx, i]) == number_of_coparison_stars:
            bad_star_mask[diff_ref_regions_idx, i] = 0
            bad_night[i] = True

        print('\tAction {:} has {:} bad comparison star(s) in reference image {:}'.format(actions[i], int(bad_star_mask[diff_ref_regions_idx, i].sum()), diff_ref_regions_idx+1))

        plt.tight_layout()
        plt.savefig('Action_{:}_comparison_star_summary_first.png'.format(actions[i]))
        plt.close()

    # Collapse over actions. From here on bad_star_mask[region, star] is True for a
    # usable star, i.e. one never flagged in any action of that region.
    bad_star_mask = ~bad_star_mask.astype(bool).any(axis=1)

    for i in range(len(bad_star_mask)):
        print(np.sum(bad_star_mask[i]))
        # If every star was rejected in this region, fall back to using all of them.
        if np.sum(bad_star_mask[i]) == 0:
            bad_star_mask[i] = ~bad_star_mask[i]
    print('\tA total of {:} unique stars have been rejected'.format(np.sum(~bad_star_mask, axis=0).sum()))

    
    ###########################################################
    # Section 2
    # Now we have the bad comparison stars, we can start to
    # go through each action, sum the flux of the good
    # comparisons, and estimate the best aperture.
    ############################################################
    print('\nSearching for best aperture using best filtered comparison stars')
    best_apertures = np.empty(actions.shape[0])
    for i in range(actions.shape[0]):
        if bad_night[i]:
            # NOTE(review): despite the message, the action is still processed below,
            # so best_apertures[i] holds a real value for the median taken afterwards.
            print('Skipping action {:} since it looks like a shit night'.format(actions[i]))

        # Region containing this action (inclusive bounds); keep the previous
        # region if none matches.
        diff_ref_regions_idx = next(
            (r for r, (start, end) in enumerate(diff_ref_regions) if start <= i <= end),
            diff_ref_regions_idx)

        # Aperture sizes available for this action, parsed from '<action>.phot<size>'.
        apertures = [j.split('.')[1][4:] + '.' + j.split('.')[2] for j in glob.glob('{:}.phot*'.format(actions[i]))]
        apertures.sort(key=float)
        aperture_values = np.array(apertures, dtype=float)

        print('\tProcessing action {:}'.format(actions[i]))
        fig1 = plt.figure(constrained_layout=False, figsize=(5, 20))
        gs1 = gridspec.GridSpec(len(apertures), 2, figure=fig1)
        rms = np.empty(len(apertures))

        for j, aperture in enumerate(apertures):
            # Drop the filename column (genfromtxt reads it as NaN).
            phot_table = np.genfromtxt('{:}.phot{:}'.format(actions[i], aperture))[:, 1:]

            # Re-reference JD to the integer day of the first frame.
            phot_table[:, 0] = phot_table[:, 0] - int(np.min(phot_table[:, 0]))
            jd = phot_table[:, 0]

            # Layout: 8 admin columns (JD-MID BJD-TDB-MID HJD-MID EXPTIME AIRMASS FWHM
            # AGERRX AGERRY), then 7 per star (X Y FLUX FLUXERR SKY SKYERR MAXPIX);
            # the target occupies the last 7 columns.
            # Summed flux of the usable comparisons divided by the target flux.
            comparison_flux = phot_table[:, (8+2)::7].T[:-1][:10][bad_star_mask[diff_ref_regions_idx]].T.sum(axis=1) / phot_table[:, -5]

            # 10-sigma clip about the median.
            median = np.median(comparison_flux)
            std = np.std(comparison_flux)
            sigma_clip_mask = (comparison_flux < median - 10*std) | (comparison_flux > median + 10*std)

            # Cloud check. Bin into 15-minute bins (0.25/24 d). If any bin's standard
            # error exceeds N_cloud_thresh, mask everything from the first to the last
            # noisy bin, padded by 10 minutes on each side.
            N_cloud_thresh = 0.1
            time_bin, mag_bin, mag_bin_err = lc_bin(jd, comparison_flux, 0.25/24)
            if time_bin.shape[0] > 4 and np.max(mag_bin_err) > N_cloud_thresh:  # need enough bins
                noisy_data_times = time_bin[mag_bin_err > N_cloud_thresh]
                noisy_data_start = np.min(noisy_data_times) - (10./24./60)
                noisy_data_end = np.max(noisy_data_times) + (10./24./60)
                sigma_clip_mask = sigma_clip_mask | ((jd > noisy_data_start) & (jd < noisy_data_end))

            # If everything ended up masked, use all points instead.
            if np.all(sigma_clip_mask):
                sigma_clip_mask = ~sigma_clip_mask
            keep = ~sigma_clip_mask

            # Weighted quadratic fit to the retained points; the RMS is that of data/model.
            weights = comparison_flux[keep]**-2 * np.sum(comparison_flux[keep]**-2)
            model = np.poly1d(np.polyfit(jd[keep], comparison_flux[keep], 2, w=weights))(jd[keep])
            rms[j] = np.std(comparison_flux[keep] / model)

            ax = plt.subplot(gs1[j, 0])
            ax.scatter(jd[keep], comparison_flux[keep], alpha=1, c='k', s=5)
            ax.plot(jd[keep], model, 'b--')
            ax.set_xticks([]); ax.set_yticks([])
            ax.set_ylabel('{:} pix'.format(aperture), fontsize=7)
            ax2 = ax.twinx()
            ax2.set_ylabel('{:,} ppm'.format(int(1e6*rms[j])))

            # Report the fraction of frames excluded by *this* aperture's mask.
            print('\t\t{:>8} {:>5} has an rms of {:>8,} ppm [{:>3.1f}% exclusion]'.format('Aperture', aperture, int(1e6*rms[j]), 100*np.mean(sigma_clip_mask)))

        # RMS (ppm) vs aperture size on the right-hand panel. Use numeric sizes:
        # plotting the strings would give a categorical axis, making both the
        # limits and the best-aperture line meaningless.
        ax = plt.subplot(gs1[:, 1])
        ax.plot(1e6*rms, aperture_values, 'b')
        ax.set_ylim(max(16, aperture_values.max()), 0)
        ax.set_xlabel('RMS [ppm]')
        ax.set_ylabel('Aperture size [pixel]')

        best_apertures[i] = round_of_rating(float(aperture_values[np.argmin(rms)]))
        print('\t\tBest aperture for action: {:} pixels'.format(best_apertures[i]))
        ax.axhline(best_apertures[i], c='b', ls='--')
        plt.suptitle('Action {:}'.format(actions[i]))
        fig1.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig1.savefig('Action_{:}_aperture_rms.png'.format(actions[i]))
        plt.close(fig1)

    


    # Now we have the best apertures for each action, let's take a median
    print('\nSummary of apertures:')
    for action, best in zip(actions, best_apertures):
        print('\taction {:>8} -> {:} pixel aperture'.format(action, best))
    best_aperture = round_of_rating(np.median(best_apertures.astype(float)))

    # Map the median back to one of the aperture strings used in the file names.
    # `apertures` is the list found for the last action processed above. If the
    # median is not exactly one of those sizes, use the nearest one rather than
    # leaving the string empty, which would break the '<action>.phot<aperture>'
    # file lookups.
    aperture_values = np.array(apertures, dtype=float)
    best_aperture_string = apertures[int(np.argmin(np.abs(aperture_values - best_aperture)))]  # <-- The string best aperture
    print('\tMedian : {:} pixel aperture'.format(best_aperture_string))

    # Now plot the apertures VS action ID
    fig1 = plt.figure(figsize=(15, 5))
    plt.scatter(range(len(best_apertures)), best_apertures, c='k', s=10)
    plt.axhline(best_aperture, ls='--', c='b')
    plt.xlabel('Action')
    plt.ylabel('Best aperture size [pix]')
    # One tick per action, so each label sits under its own point.
    plt.xticks(range(len(actions)), actions)
    fig1.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig1.savefig('Actions_aperture_summary.png')
    plt.close()

    
    ###########################################################
    # Section 3
    # Now we have the best aperture, we can start to extract
    # photometry
    ############################################################
    print('\nPlotting diagnositics...')

    def _set_ylim_4sigma(axis, values):
        """Limit `axis` to median +/- 4 std of `values`; leave it alone if not finite."""
        centre = np.median(values)
        spread = np.std(values)
        try:
            axis.set_ylim(centre - 4*spread, centre + 4*spread)
        except ValueError:  # matplotlib rejects NaN/Inf limits
            pass

    master_rows = []  # one (n_frames, 11) array per accepted action
    for i in range(actions.shape[0]):
        print('\tProcessing action {:}'.format(actions[i]))
        # Region containing this action (inclusive bounds); keep the previous
        # region if none matches.
        diff_ref_regions_idx = next(
            (r for r, (start, end) in enumerate(diff_ref_regions) if start <= i <= end),
            diff_ref_regions_idx)

        # Drop the filename column (genfromtxt reads it as NaN).
        phot_table = np.genfromtxt('{:}.phot{:}'.format(actions[i], best_aperture_string))[:, 1:]

        time = phot_table[:, 0]  # JD
        BJD = phot_table[:, 1]
        HJD = phot_table[:, 2]
        time_axis = Time(time, format='jd')

        # The target occupies the last 7 columns: X Y FLUX FLUXERR SKY SKYERR MAXPIX.
        target_raw_flux = phot_table[:, -5]
        if np.median(target_raw_flux) < 0:
            print('The target flux is below 0')
            print(target_raw_flux)
            print('Exiting to avoid spurious results')
            sys.exit(1)
        target_sky = phot_table[:, -3]

        good_comps = bad_star_mask[diff_ref_regions_idx]  # True = usable comparison star
        comp_star_summed_flux = phot_table[:, (8+2)::7].T[:-1][:10][good_comps].T.sum(axis=1)  # one value per frame
        target_cop_flux = target_raw_flux / comp_star_summed_flux

        target_X = phot_table[:, -7]
        target_Y = phot_table[:, -6]
        max_flux = phot_table[:, -1]
        airmass = phot_table[:, 4]
        fwhm = phot_table[:, 5]

        # Collect rows for the master table. Accumulate in a list so that a bad first
        # action does not leave an empty placeholder to stack onto.
        if not bad_night[i]:
            master_rows.append(np.column_stack((time, BJD, HJD, target_raw_flux, target_cop_flux,
                                                target_X, target_Y, airmass, fwhm, target_sky, max_flux)))
        else:
            print('\t\tNot appending action {:}'.format(actions[i]))

        # Summary plot of the action.
        # NOTE(review): Axes.plot_date is deprecated since Matplotlib 3.9.
        fig, (ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8) = plt.subplots(nrows=8, ncols=1, figsize=(7, 18), sharex=True)

        # Differential photometry, raw and in 10-minute bins.
        ax1.plot_date(time_axis.plot_date, target_cop_flux, c='k', label='raw', markersize=1)
        time_bin, flux_bin, err_bin = lc_bin(time, target_cop_flux, 10./24/60)
        time_axis_ = Time(time_bin, format='jd')
        ax1.plot_date(time_axis_.plot_date, flux_bin, 'ro-', label='10 minute bin', markersize=5, alpha=0.7)
        _set_ylim_4sigma(ax1, target_cop_flux)
        ax1.legend(prop={'size': 7})
        ax1.set_ylabel('Target/Ref')

        # Target X position (non-positive values skipped).
        mask = target_X > 0
        ax2.plot_date(time_axis.plot_date[mask], target_X[mask], c='k', markersize=1)
        ax2.set_ylabel('X position')
        ax2.get_yaxis().get_major_formatter().set_useOffset(False)

        # Target Y position (non-positive values skipped).
        mask = target_Y > 0
        ax3.plot_date(time_axis.plot_date[mask], target_Y[mask], c='k', markersize=1)
        ax3.set_ylabel('Y position')
        ax3.get_yaxis().get_major_formatter().set_useOffset(False)

        ax4.plot_date(time_axis.plot_date, airmass, fmt='k')
        ax4.set_ylabel('Airmass')

        ax5.plot_date(time_axis.plot_date, target_raw_flux, fmt='k', markersize=1)
        ax5.set_ylabel('Target flux')
        _set_ylim_4sigma(ax5, target_raw_flux)

        ax6.plot_date(time_axis.plot_date, fwhm, fmt='k', markersize=1)
        ax6.set_ylabel('FWHM [pix]')
        ax6.set_xlabel('Time [UTC]')

        ax7.plot_date(time_axis.plot_date, target_sky, fmt='k', markersize=1)
        ax7.set_ylabel('Target sky counts [pix]')
        ax7.set_xlabel('Time [UTC]')

        ax8.plot_date(time_axis.plot_date, max_flux, fmt='k', markersize=1)
        ax8.set_ylabel('Max count in aperture')
        ax8.set_xlabel('Time [UTC]')

        plt.suptitle('Action {:}'.format(actions[i]))
        fig.align_ylabels([ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8])
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig('Action_{:}_summary.png'.format(actions[i]))
        plt.close()

    if not master_rows:
        sys.exit('No usable actions left to build the master photometry table.')
    # Columns: time, BJD, HJD, target_raw_flux, target_cop_flux, X, Y, airmass,
    # FWHM, target_sky, max_flux
    master_phot_table = np.vstack(master_rows)
    
    
    ###########################################################
    # Section 4
    # Detrending.
    # This needs to be relatively light, so we only detrend
    # against the target X and Y position.
    ############################################################
    # NOTE(review): the message below mentions airmass and PCA, but the fit is a
    # plain least-squares on X and Y with no constant term.
    print('\nDetrending with X, Y & airmass using PCA...')

    flux_ratio = master_phot_table[:, 4]  # target / summed comparison flux

    # 5-sigma clip so outliers/dropouts do not drive the fit.
    median = np.median(flux_ratio)
    std = np.std(flux_ratio)
    mask = ~((flux_ratio < median - 5*std) | (flux_ratio > median + 5*std))

    A = master_phot_table[:, (5, 6)][mask]  # target X, Y
    try:
        linalg_coeffs = np.linalg.lstsq(A, flux_ratio[mask], rcond=None)[0]
        for k, coeff in enumerate(linalg_coeffs):
            print('\tCoefficient {:} = {:}'.format(k+1, coeff))
        trend = A @ linalg_coeffs
        # Interpolate the trend back onto every frame, including the clipped ones.
        # NOTE(review): np.interp assumes the JD column increases monotonically.
        trend = np.interp(master_phot_table[:, 0], master_phot_table[:, 0][mask], trend)
    except (np.linalg.LinAlgError, ValueError):
        # Flat fallback trend. It must be a full column, otherwise appending it to
        # the table below fails.
        trend = np.full(master_phot_table.shape[0], np.median(flux_ratio[mask]))

    detrended = flux_ratio / trend
    # Resulting columns: time, BJD, HJD, target_raw_flux, target_cop_flux, X, Y,
    # airmass, FWHM, target_sky, max_flux, trend, target_cop_flux_detrended, mag
    master_phot_table = np.column_stack((master_phot_table, trend, detrended, -2.5*np.log10(detrended)))

    # Now save master photometry
    print('\nSaving photometry as master_photometry.fits')
    names = ['JD', 'BJD', 'HJD', 'Target_raw_flux', 'Target_over_comp_flux', 'X', 'Y', 'Airmass', 'FWHM', 'Target_sky', 'Max flux in aperture', 'Trend', 'Target_over_comp_flux_detrended', 'Mag_detrended']
    t = Table(master_phot_table, names=names)
    t.add_column(Column(-2.5*np.log10(t['Target_over_comp_flux']), name='Mag'))
    t.write('master_photometry.fits', overwrite=True)
    print('Done')
    

    '''
    ###########################################################
    # Section 1
    # Find best aperture
    ############################################################
    best_apertures = np.empty(actions.shape[0])
    number_of_coparison_stars = 0
    for i in range(actions.shape[0]):
        print('\nProcessing action {:}'.format(actions[i]))

        # Find apertures
        print('\tFinding best aperture')
        print('\t---------------------')

        apertures = [i.split('.')[1][4:] +'.' +  i.split('.')[2] for i in glob.glob('{:}.phot*'.format(actions[i]))]
        apertures.sort(key=float)

        # Now set up the figure to plot 
        fig1 = plt.figure(constrained_layout=False, figsize=(5,20))
        gs1 = gridspec.GridSpec(len(apertures), 2, figure=fig1)
        ax_count = 0 
        rms = np.empty(len(apertures))

        for j in range(len(apertures)):

            # Load the photometry table
            phot_table = np.genfromtxt('{:}.phot{:}'.format(actions[i], apertures[j]))[:,1:] # we can skip the filename since it reads as Nan becuase it's a U32 

            # Modify the JD 
            phot_table[:, 0] = phot_table[:, 0] - int(np.min(phot_table[:, 0]))

            # 7 admin rows : JD-MID BJD-TDB-MID HJD-MID EXPTIME AIRMASS FWHM AGERRX AGERRY
            # 7 for the target :  X Y FLUX FLUXERR SKY SKYERR MAXPIX
            # 7 for the comps  :  X Y FLUX FLUXERR SKY SKYERR MAXPIX
            number_of_coparison_stars = (phot_table.shape[1] - (7+7)) // 7

            # Get the flux of each comparison star
            median = np.median(phot_table[:, (8+7+2)::7].sum(axis=1))
            std = np.std(phot_table[:, (8+7+2)::7].sum(axis=1))
            mask = (phot_table[:, (8+7+2)::7].sum(axis=1) < median - 10*std) | (phot_table[:, (8+7+2)::7].sum(axis=1) > median + 10*std)

            # Before we fit a line, we want to do a bin of the nights lightcurve and look at the variance
            # lets do a 10-minute bin
            time_bin, mag_bin, mag_bin_err = lc_bin(phot_table[:, 0], phot_table[:, (8+7+2)::7].sum(axis=1), 10./24./60)
            if (time_bin.shape[0] > 0): # incase the time axis is too short
                if np.max(mag_bin_err) > 1000:
                    noisy_data_times = time_bin[np.where(mag_bin_err > 1000)]
                    if noisy_data_times.shape[0] == 0 :
                        noisy_data_start = noisy_data_times[0] - (10./24./60)
                        noisy_data_end   = noisy_data_times[0] + (10./24./60)
                    else:
                        noisy_data_start = np.min(noisy_data_times) - (10./24./60)
                        noisy_data_end   = np.max(noisy_data_times) + (10./24./60)                    
                    mask = mask | ((phot_table[:, 0] > noisy_data_start) & (phot_table[:, 0] < noisy_data_end))

            # Now lets fit a line
            if phot_table[:, 0][~mask].shape[0] == 0 : mask = ~mask # in case all is masked, don't use it
            weights = phot_table[:, (8+7+2)::7].sum(axis=1)[~mask]**-2*np.sum(phot_table[:, (8+7+2)::7].sum(axis=1)[~mask]**-2)
            model = np.poly1d(np.polyfit(phot_table[:, 0][~mask], phot_table[:, (8+7+2)::7].sum(axis=1)[~mask], 2, w = weights ))(phot_table[:, 0][~mask])

            # Now calculate the RMS 
            rms[j] = np.std(phot_table[:, (8+7+2)::7].sum(axis=1)[~mask] / model)

            # Now let's plot
            ax = plt.subplot(gs1[ax_count, 0])
            ax.scatter(phot_table[:, 0][~mask], phot_table[:, (8+7+2)::7].sum(axis=1)[~mask], alpha = 0.04, c='k', s= 5)
            ax.plot(phot_table[:, 0][~mask], model, 'b--')
            ax.set_xticks([]); ax.set_yticks([])
            ax.set_ylabel('{:} pix'.format(apertures[j]), fontsize=7)

            # Now verbose and set up for next one
            print('\t{:>8} {:>5} has an rms of {:>8,} ppm [{:>3.1f}% exclusion]'.format('Aperture', apertures[j], int(1e6*rms[j]), 100*(np.sum(mask)/mask.shape[0]) ))
            ax_count += 1

        ax = plt.subplot(gs1[:, 1])
        ax.plot(1e6*rms, apertures, 'b')
        ax.set_ylim(16,0)
        ax.set_xlabel('RMS [ppm]')
        ax.set_ylabel('Aperture size [pixel]')

        best_apertures[i] = apertures[np.argmin(rms)]

        ax.axhline(float(best_apertures[i])-1, c='b', ls='--')
        plt.suptitle('Action {:}'.format(actions[i]))
        fig1.tight_layout(rect=[0, 0.03, 1, 0.95])
        fig1.savefig('Action_{:}_aperture_rms.png'.format(actions[i]))
        plt.close(fig1)

    print('\nSummary of apertures:')
    for i in range(len(actions)) : print('\taction {:>8} -> {:} pixel aperture'.format(actions[i], best_apertures[i]))
    best_aperture = round_of_rating(np.median(best_apertures.astype(float)))      
    for i in range(len(apertures)) : 
        if best_aperture == np.array(apertures).astype(float)[i] : best_aperture_string = apertures[i]      # <-- The string best aperture

    print('\tMedian aperture is {:}'.format(best_aperture_string))
    fig1 = plt.figure(figsize=(15,5))
    plt.scatter(range(len(best_apertures)), best_apertures, c='k', s=10)
    plt.axhline(best_aperture, ls='--', c='b')
    plt.xlabel('Action')
    plt.ylabel('Best aperture size [pix]')
    plt.gca().set_xticklabels(actions)
    fig1.tight_layout(rect=[0, 0.03, 1, 0.95])
    fig1.savefig('Actions_aperture_summary.png')
    plt.close()



    ###########################################################
    # Section 2
    # Now we have the best aperture, we need to check stars
    # for variability. 
    ############################################################
    print('\nFinding bad comparison stars...')

    master_phot_table = []
    bad_star_mask = np.zeros((actions.shape[0], number_of_coparison_stars))

    for i in range(actions.shape[0]):
        print('\tProcessing action {:}'.format(actions[i]))

        # Load the photometry table
        phot_table = np.genfromtxt('{:}.phot{:}'.format(actions[i], best_aperture_string))[:,1:] # we can skip the filename since it reads as Nan becuase it's a U32

        # Now set up the figure to plot 
        fig1, ax = plt.subplots(nrows = number_of_coparison_stars, ncols=1, figsize=(5,50))


        
        for j in range(number_of_coparison_stars):
            median = np.median(phot_table[:, 8+7+2 + j*7])
            std = np.std(phot_table[:, 8+7+2 + j*7])

            if (np.min(phot_table[:, 8+7+2 + j*7] ) < 10) or (np.std(phot_table[:, 8+7+2 + j*7] ) < 1) : bad_star_mask[i,j] = 1 # catch wierd cases with low photometry (faint)


            mask = (phot_table[:, 8+7+2 + j*7] < median - 5*std) | (phot_table[:, 8+7+2 + j*7] > median + 5*std)
            ax[j].scatter(phot_table[:, 0][~mask], phot_table[:, 8+7+2 + j*7][~mask], alpha = 0.1, c='k', s= 5)
            ax[j].set_xticks([]); ax[j].set_yticks([])
            ax[j].set_ylabel(str(j+1), rotation=0) #

            if bad_star_mask[i,j] == 1: # These are systematicall shit nights
                ax[j].set_facecolor('xkcd:salmon')
                ax[j].set_facecolor((1.0, 0.47, 0.42)) 
        
        plt.tight_layout()
        plt.savefig('Action_{:}_comparison_star_summary.png'.format(actions[i]))
        plt.close()

    # Now convert badstarmask into bool 
    bad_star_mask = bad_star_mask.sum(axis=0).astype(np.bool)


    ###########################################################
    # Section 2.5
    # Now we know which comparison stars are bad, build the light-curve
    # array and make a diagnostic summary plot for each action.
    ############################################################

    def _set_clipped_ylim(ax, values, nsigma=4):
        """Clip ax's y-range to median +/- nsigma * std of values.

        NaNs are ignored when computing the statistics, so a few failed
        frames do not disable the clipping. If the limits are still not
        finite (e.g. every value is NaN/inf) the axis keeps its autoscaled
        range instead of raising.
        """
        median = np.nanmedian(values)
        std = np.nanstd(values)
        low, high = median - nsigma * std, median + nsigma * std
        if np.isfinite(low) and np.isfinite(high):
            ax.set_ylim(low, high)

    print('\nPlotting diagnositics...')
    master_rows = []  # one (n_frames, 10) array per action; stacked after the loop
    for i in range(actions.shape[0]):
        print('\tProcessing action {:}'.format(actions[i]))

        # Load the photometry table. Column 0 (the filename) is dropped
        # because genfromtxt reads it as NaN (it is a string column).
        phot_table = np.genfromtxt('{:}.phot{:}'.format(actions[i], best_aperture_string))[:, 1:]

        # Time axes
        time = phot_table[:, 0]  # JD
        BJD = phot_table[:, 1]
        # NOTE(review): HJD is read from the same column as BJD (index 1);
        # confirm against the .phot column layout whether this is intended.
        HJD = phot_table[:, 1]
        time_axis = Time(phot_table[:, 0], format='jd')

        # Target-star columns start at index 8 (X at +0, Y at +1, flux at +2,
        # sky at +4); comparison-star flux columns follow with a stride of 7.
        target_raw_flux = phot_table[:, (8+2)]
        target_sky = phot_table[:, (8+4)]
        target_X = phot_table[:, (8+0)]
        target_Y = phot_table[:, (8+1)]

        # Sum only the comparison stars that were not flagged as bad; one
        # value per frame, matching the time axis.
        comp_star_summed_flux = phot_table[:, (8+7+2)::7][:, ~bad_star_mask].sum(axis=1)
        target_cop_flux = target_raw_flux / comp_star_summed_flux

        # Row layout: time, BJD, HJD, target_raw_flux, target_cop_flux,
        # X, Y, airmass (col 4), FWHM (col 5), target_sky
        master_rows.append(np.array([time, BJD, HJD, target_raw_flux, target_cop_flux,
                                     target_X, target_Y, phot_table[:, 4], phot_table[:, 5],
                                     target_sky]).T)

        # Summary plot of the action
        fig, (ax1, ax2, ax3, ax4, ax5, ax6, ax7) = plt.subplots(nrows=7, ncols=1, figsize=(7, 15), sharex=True)

        # Target / comparison flux, raw and in 10-minute bins
        ax1.plot_date(time_axis.plot_date, target_cop_flux, c='k', label='raw', markersize=1)
        time_bin, flux_bin, err_bin = lc_bin(phot_table[:, 0], target_cop_flux, 10./24/60)
        time_axis_ = Time(time_bin, format='jd')
        ax1.plot_date(time_axis_.plot_date, flux_bin, 'ro-', label='10 minute bin', markersize=2, alpha=0.4)
        _set_clipped_ylim(ax1, target_cop_flux)
        ax1.legend(prop={'size': 7})
        ax1.set_ylabel('Target/Ref')

        # X position (non-positive values are masked out)
        mask = target_X > 0
        ax2.plot_date(time_axis.plot_date[mask], target_X[mask], c='k', markersize=1)
        ax2.set_ylabel('X position')
        ax2.get_yaxis().get_major_formatter().set_useOffset(False)

        # Y position (non-positive values are masked out)
        mask = target_Y > 0
        ax3.plot_date(time_axis.plot_date[mask], target_Y[mask], c='k', markersize=1)
        ax3.set_ylabel('Y position')
        ax3.get_yaxis().get_major_formatter().set_useOffset(False)

        # Airmass
        ax4.plot_date(time_axis.plot_date, phot_table[:, 4], fmt='k')
        ax4.set_ylabel('Airmass')

        # Raw target flux
        ax5.plot_date(time_axis.plot_date, target_raw_flux, fmt='k', markersize=1)
        ax5.set_ylabel('Target flux')
        _set_clipped_ylim(ax5, target_raw_flux)

        # FWHM
        ax6.plot_date(time_axis.plot_date, phot_table[:, 5], fmt='k', markersize=1)
        ax6.set_ylabel('FWHM [pix]')

        # Sky. With sharex=True only the bottom axis should carry the time
        # label; labelling ax6 as well drew a stray label between panels.
        ax7.plot_date(time_axis.plot_date, target_sky, fmt='k', markersize=1)
        ax7.set_ylabel('Target sky counts [pix]')
        ax7.set_xlabel('Time [UTC]')

        # Tidy up and save
        plt.suptitle('Action {:}'.format(actions[i]))
        fig.align_ylabels([ax1, ax2, ax3, ax4, ax5, ax6, ax7])
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        plt.savefig('Action_{:}_summary.png'.format(actions[i]))
        plt.close()

    # Stack every action into the master table in one go (same contents as
    # vstack-ing inside the loop, without re-copying the table each action).
    master_phot_table = np.vstack(master_rows)


    ###########################################################
    # Section 3
    # Detrending.
    # This needs to be relatively light, so we only detrend with a
    # linear combination of X, Y and airmass (no constant term).
    ############################################################
    print('\nDetrending with X, Y & airmass using linear least squares...')

    # Design matrix: X, Y, airmass (FWHM and sky were tried but did not help).
    # NOTE(review): the summary plots mask X/Y <= 0 as invalid, but those
    # frames are still included in this fit -- confirm that is intended.
    A = master_phot_table[:, (5, 6, 7)]
    # Fit the target/comparison flux (column 4). rcond=None selects the
    # machine-precision cutoff (NumPy's current default) and avoids the
    # FutureWarning raised when rcond is omitted on older NumPy.
    linalg_coeffs = np.linalg.lstsq(A, master_phot_table[:, 4], rcond=None)[0]
    for i, coeff in enumerate(linalg_coeffs):
        print('\tCoefficient {:} = {:}'.format(i + 1, coeff))

    trend = (A * linalg_coeffs).sum(axis=1)  # least-squares model of column 4
    detrended_flux = master_phot_table[:, 4] / trend
    mag_detrended = -2.5 * np.log10(detrended_flux)

    # Append trend, detrended flux and detrended magnitude. Final columns:
    # time, BJD, HJD, target_raw_flux, target_cop_flux, X, Y, airmass, FWHM,
    # target_sky, trend, target_cop_flux_detrended, mag
    master_phot_table = np.column_stack((master_phot_table, trend, detrended_flux, mag_detrended))

    # Now save master photometry
    print('\nSaving photometry as master_photometry.fits')
    names = ['JD','BJD', 'HJD', 'Target_raw_flux', 'Target_over_comp_flux', 'X', 'Y', 'Airmass', 'FWHM', 'Target_sky', 'Trend', 'Target_over_comp_flux_detrended', 'Mag_detrended']
    t = Table(master_phot_table, names=names)
    t.write('master_photometry.fits', overwrite=True)
    print('Done')
    '''