#!python

"""
STARTING OVER WITH A COMPLETE REWRITE

PROCESS:
    Take in which object to analyse
    Go find its folder and phot output

    For each action
        For each aperture size
            For each comparison star
                work out the RMS/quality
                flag any potentially bad stars

            Make the light curve
            Analyse the RMS (OOT if transit) in data
        Make a final light curve from the best aperture and comparisons
"""
import sys
import logging
import matplotlib.pyplot as plt
import numpy as np
import glob
from astropy.table import Table, Column
import matplotlib.gridspec as gridspec
from scipy.stats import sem 
from astropy.time import Time
import warnings
from tqdm import tqdm 

# Suppress warnings process-wide: FutureWarning explicitly, then every other
# category.  ``np.warnings`` was only a leaked alias of the standard-library
# ``warnings`` module and is no longer available in NumPy >= 1.24, so the
# stdlib module is called directly.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings('ignore')

# Default font size for every figure produced by this script.
plt.rcParams.update({'font.size': 7})

def lc_bin(time, flux, bin_width):
    """
    Bin a light curve into fixed-width time bins.

    Bins start at the earliest time stamp and are ``bin_width`` wide.  Enough
    bins are created to reach past the latest time stamp, so samples at the
    end of the series are binned rather than discarded.

    Parameters
    ----------
    time : array-like
        Sample times.  Must be in the same units as ``bin_width``.
    flux : array-like
        Measurements, one per entry of ``time``.
    bin_width : float
        Width of each bin, in the units of ``time``.  Must be positive.

    Returns
    -------
    time_bin, flux_bin, err_bin : numpy.ndarray
        Bin centres, mean flux per bin, and standard error of the mean per
        bin.  Bins holding fewer than two samples have no defined standard
        error and are left out of all three arrays.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive.
    """
    if bin_width <= 0:
        raise ValueError("bin_width must be positive")

    # asanyarray keeps ndarray subclasses (e.g. table columns) untouched
    # while letting plain sequences be indexed with boolean masks.
    time = np.asanyarray(time)
    flux = np.asanyarray(flux)
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    start = np.min(time)
    span = np.max(time) - start
    # floor(span / width) + 1 bins guarantees the last edge lies strictly
    # beyond the latest sample, so that sample falls inside a real bin.
    n_bins = int(np.floor(span / bin_width)) + 1
    edges = start + bin_width * np.arange(n_bins + 1)

    dig = np.digitize(time, edges)
    centres = (edges[1:] + edges[:-1]) / 2
    means = np.full(n_bins, np.nan)
    errors = np.full(n_bins, np.nan)
    for i in range(1, n_bins + 1):
        members = flux[dig == i]
        if len(members) == 0:
            continue
        means[i - 1] = members.mean()
        # The standard error needs at least two samples; leave NaN otherwise.
        if len(members) > 1:
            errors[i - 1] = sem(members)

    keep = ~np.isnan(errors)
    return centres[keep], means[keep], errors[keep]


def get_nights(time, mag, threshold):
    """
    Split a time series into contiguous segments ("nights").

    A sample counts as in-night when ``np.gradient(time)`` at that sample is
    below ``threshold``.  The gradient is taken with respect to sample index
    and is a central difference in the interior, so the samples on both
    sides of a large gap exceed the threshold and are not part of any
    segment.  Every maximal run of more than one consecutive in-night sample
    becomes one segment, including the run that reaches the end of the data.

    Parameters
    ----------
    time : array-like
        Sample times, in observation order.
    mag : array-like
        Measurements, one per entry of ``time``.
    threshold : float
        Upper limit on the per-sample time gradient for a sample to be
        treated as in-night (same units as ``time``).

    Returns
    -------
    time_, mag_ : list of numpy.ndarray
        Times and measurements of each segment, in order.  Both lists are
        empty when no segment is found or fewer than two samples are given.
    """
    time = np.asarray(time)
    mag = np.asarray(mag)
    # np.gradient needs at least two samples.
    if time.size < 2:
        return [], []

    in_night = np.asarray(np.gradient(time) < threshold, dtype=bool)

    # Pad with False on both ends so every run has a rising and a falling
    # edge; edges then alternate start, stop, start, stop, ...
    # Runs are located with array operations, so no per-sample loop (and no
    # progress bar) is involved.
    padded = np.concatenate(([False], in_night, [False]))
    boundaries = np.flatnonzero(padded[1:] != padded[:-1])

    time_, mag_ = [], []
    for start, stop in zip(boundaries[::2], boundaries[1::2]):
        # Isolated single samples are not kept as a segment.
        if stop - start > 1:
            time_.append(np.array(time[start:stop]))
            mag_.append(np.array(mag[start:stop]))
    return time_, mag_

def _scatter_nights(ax, night_times, night_mags, colour, alpha, spacing=0.4):
    """Scatter every night on ``ax``, re-zeroed to its own first sample and shifted by ``index * spacing``."""
    for idx, (night_t, night_m) in enumerate(zip(night_times, night_mags)):
        ax.scatter(night_t - night_t[0] + idx * spacing, night_m,
                   c=colour, s=10, alpha=alpha)


def _clip_to_typical_scatter(ax, night_mags):
    """Set the y-limits of ``ax`` to median +/- 5 x the median per-night std (when finite); return that std."""
    if not night_mags:
        return np.nan
    std = np.median([np.std(night) for night in night_mags])
    median = np.median([np.median(night) for night in night_mags])
    if np.isfinite(median - 5 * std):
        ax.set_ylim(median - 5 * std, median + 5 * std)
    return std


def _draw_panel(ax, jd, mag, label):
    """
    Draw one panel: unbinned nights in black, 30-minute-binned nights in red.

    Returns the binned series from ``lc_bin`` (time, flux, error) and the
    number of binned nights that were plotted.
    """
    night_t, night_m = get_nights(jd, mag, 0.05)
    _scatter_nights(ax, night_t, night_m, 'k', 0.4)
    std = _clip_to_typical_scatter(ax, night_m)

    # 30 minutes expressed in days.
    time_bin, flux_bin, err_bin = lc_bin(jd, mag, 30/24/60)
    bin_t, bin_m = get_nights(time_bin, flux_bin, 0.3)
    _scatter_nights(ax, bin_t, bin_m, 'r', 1)

    std_bin = np.std(flux_bin)
    # The bracketed percentages convert each magnitude scatter to a flux
    # percentage via 100 - 100 * 10**(-0.4 * std).
    ax.set_ylabel('{}\nRAW STD = {:.2f} mmag [{:.1f}%]\n30-min STD = {:.2f} mmag [{:.1f}%]'.format(
        label,
        std*1e3, 100-100*10**(-0.4*std),
        std_bin*1e3, 100-100*10**(-0.4*std_bin)))
    ax.get_xaxis().set_ticks([])
    return time_bin, flux_bin, err_bin, len(bin_t)


def main():
    """
    Plot stacked nightly light curves from ``master_photometry.fits``.

    The top panel shows the ``Mag`` column and the bottom panel the
    ``Mag_detrended`` column, both against ``JD``.  Also writes
    ``TLS_RAW.dat`` (JD, detrended magnitude, constant 1e-3 error),
    ``TLS_30.dat`` (30-minute-binned detrended series) and the figure
    ``stacked_actions.png``.
    """
    t = Table.read('master_photometry.fits')

    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(15, 10))

    _draw_panel(ax1, t['JD'], t['Mag'], 'Raw Mag')

    np.savetxt('TLS_RAW.dat', np.column_stack([
        np.asarray(t['JD']),
        np.asarray(t['Mag_detrended']),
        np.full(len(t), 1e-3),
    ]))

    time_bin, flux_bin, err_bin, n_nights = _draw_panel(
        ax2, t['JD'], t['Mag_detrended'], 'Detrended Mag')
    np.savetxt('TLS_30.dat', np.column_stack([time_bin, flux_bin, err_bin]))

    # The title reports the number of binned, detrended nights.
    ax1.set_title('{:.0f} stacked actions'.format(n_nights))
    plt.savefig('stacked_actions.png')
    plt.close()


if __name__ == "__main__":
    main()