#!python
''' #!/home/u1870241/anaconda3/bin/python3 '''
import matplotlib.pyplot as plt 
import os,sys 
import matplotlib.gridspec as gridspec
import argparse 
import numpy as np
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path
from bruce.binarystar.lc import _lc , kernel_lc, sum_reduce, lc
import numba 
from scipy.signal import find_peaks
from lightkurve.lightcurve import TessLightCurve
import time as time_pack
from astropy.stats import sigma_clip
from numpy import mean
import subprocess 
import emcee, corner

description = '''monofind'''

# Argument parser.
# NOTE: prog was previously the copy-pasted 'predict'; it only affects the usage line.
parser = argparse.ArgumentParser('monofind', description=description)

parser.add_argument('--trial', action="store_true", default=False,
                    help="Plot the prepared (resampled) light curve and exit without searching")
parser.add_argument('--flatten', action="store_true", default=False, help="Flatten the LC")
parser.add_argument('--remove_dropouts', action="store_true", default=False,
                    help="Replace isolated single-point dropouts (>0.05 mag from both neighbours) with random noise")

parser.add_argument("filename",
                    help='The filename.')

# The literal *string* 'None' (not the None object) is the "no mask file" sentinel.
parser.add_argument('-g',
                    '--mask',
                    help='The mask used to exclude bad regions. A two-column text file specifying the start and end of bad points.',
                    default='None')

# NOTE: this value is parsed but not currently passed to the flatten() call.
parser.add_argument('-i',
                    '--window_length',
                    help='The flattening window.',
                    default=11, type=int)

parser.add_argument('-j',
                    '--saveplace',
                    help='Where to save',
                    default='/ngts/scratch/monofind/')

def lc_sort(time, mag, mag_err):
    """Sort a light curve by time, carrying mag and mag_err along.

    Rows are ordered as (time, mag, mag_err) tuples, so equal times are broken
    by mag and then by mag_err.

    Returns
    -------
    time, mag, mag_err : ndarray
        New arrays in sorted order.
    """
    ordered_rows = sorted(zip(time, mag, mag_err))
    t_sorted, m_sorted, e_sorted = (
        np.array([row[column] for row in ordered_rows]) for column in range(3)
    )
    return t_sorted, m_sorted, e_sorted


def lc_resample(time, mag, mag_err, width, weighted_mean, step=0.5/24):
    """Linearly resample a light curve onto a regular time grid.

    The grid runs from ``min(time) - width/2`` up to (but excluding)
    ``max(time) + width/2`` in increments of ``step`` (default 0.5/24, i.e.
    30 minutes if ``time`` is in days). The half-width padding on each side
    lets a window of length ``width`` be centred on every original epoch.

    Parameters
    ----------
    time, mag, mag_err : array_like
        Input light curve; ``time`` must be increasing (``np.interp`` requires it).
    width : float
        Total padding added around the data (half on each side).
    weighted_mean : float
        Magnitude assigned to grid points outside the data range.
    step : float, optional
        Grid spacing. It was previously hard-coded; the default keeps the old
        cadence.

    Returns
    -------
    time_new, mag, mag_err : ndarray
        The grid and the linearly interpolated magnitudes and errors.
    """
    time_new = np.arange(np.min(time) - width/2, np.max(time) + width/2, step)
    mag = np.interp(time_new, time, mag, left=weighted_mean, right=weighted_mean)
    # NOTE(review): both the left and the right padding reuse mag_err[0].
    mag_err = np.interp(time_new, time, mag_err, left=mag_err[0], right=mag_err[0])
    return time_new, mag, mag_err

@numba.njit(parallel=True)
def main_func(time, mag, mag_err, weighted_mean, model_period, b, spots, chi_ref,cube):
    """Fill ``cube`` in place with ``_lc`` results minus ``chi_ref`` over a grid.

    The grid axes are radius_1 (index ``i`` into ``radius_1s``), radius ratio
    (index ``j`` into ``ks``) and trial epoch ``t_zero = time[k]``. Each cell
    gets three evaluations of ``_lc`` (from ``bruce.binarystar.lc``), all with
    ``loglike_switch=1``:

    * ``cube[i, j, k, 0]``: ``ld_law_1=2``, ``incl=pi/2``
    * ``cube[i, j, k, 1]``: ``ld_law_1=0``, ``incl=pi/2``
    * ``cube[i, j, k, 2]``: ``ld_law_1=2``, ``incl=arccos(b * radius_1s[i])``

    ``cube`` is indexed as 4-D (the last axis needs at least 3 entries). Nothing
    is returned.

    NOTE(review): ``radius_1s`` and ``ks`` are not parameters. They are read as
    module globals (assigned inside the ``__main__`` block), and Numba freezes
    globals at compile time, so they must exist before the first call.
    NOTE(review): ``loglike_switch=1`` presumably makes ``_lc`` return a
    log-likelihood, and ``incl`` here is presumably in radians (``log_prob``
    passes degrees to ``lc``). Confirm both against bruce.
    NOTE(review): per the Numba docs, nested ``prange`` loops are serialised;
    only the outermost loop (over ``radius_1s``) runs in parallel.
    This function is not called anywhere in this file.
    """
    for i in numba.prange(radius_1s.shape[0]):
        for j in numba.prange(ks.shape[0]):
            for k in numba.prange(time.shape[0]):
                # Only used by the third (inclined) evaluation below.
                incl = np.arccos(b*radius_1s[i])

                # Slot 0: ld_law_1=2, edge-on orbit.
                cube[i,j,k, 0] =_lc(time, mag, mag_err, 0.0, weighted_mean,
                                    time[k], model_period,
                                    radius_1s[i], ks[j] ,
                                    fs=0., fc=0., 
                                    q=0., albedo=0.,
                                    alpha_doppler=0., K1=0.,
                                    spots = spots, omega_1=1., 
                                    incl = np.pi/2.,
                                    ld_law_1=2, ldc_1_1=0.8, ldc_1_2=0.8, gdc_1=0.4,
                                    SBR=0., light_3=0.,
                                    E_tol=1e-5,
                                    loglike_switch=1 ) - chi_ref



                # Slot 1: same as slot 0 but with ld_law_1=0.
                cube[i,j,k, 1] =_lc(time, mag, mag_err, 0.0, weighted_mean,
                                    time[k], model_period,
                                    radius_1s[i], ks[j] ,
                                    fs=0., fc=0., 
                                    q=0., albedo=0.,
                                    alpha_doppler=0., K1=0.,
                                    spots = spots, omega_1=1., 
                                    incl = np.pi/2.,
                                    ld_law_1=0, ldc_1_1=0.8, ldc_1_2=0.8, gdc_1=0.4,
                                    SBR=0., light_3=0.,
                                    E_tol=1e-5,
                                    loglike_switch=1 ) - chi_ref

                # Slot 2: ld_law_1=2 with the impact-parameter-derived inclination.
                cube[i,j,k, 2] =_lc(time, mag, mag_err, 0.0, weighted_mean,
                                    time[k], model_period,
                                    radius_1s[i], ks[j] ,
                                    fs=0., fc=0., 
                                    q=0., albedo=0.,
                                    alpha_doppler=0., K1=0.,
                                    spots = spots, omega_1=1., 
                                    incl = incl,
                                    ld_law_1=2, ldc_1_1=0.8, ldc_1_2=0.8, gdc_1=0.4,
                                    SBR=0., light_3=0.,
                                    E_tol=1e-5,
                                    loglike_switch=1 ) - chi_ref



def get_lc(time, modelx, modely, t_zero, weighted_mean):
    """Evaluate a tabulated model, shifted to epoch ``t_zero``, at ``time``.

    ``modelx`` holds model times relative to mid-transit and ``modely`` the
    matching magnitudes. Times outside the shifted model range get
    ``weighted_mean``.
    """
    shifted_x = modelx + t_zero
    baseline = weighted_mean
    return np.interp(time, shifted_x, modely, left=baseline, right=baseline)

def get_lc_loglike(time, mag, mag_err, modelx, modely, t_zero, weighted_mean,chi_ref):
    """Return the Gaussian log-likelihood of a shifted tabulated model, minus ``chi_ref``.

    The model is evaluated with ``get_lc``. The value is
    ``-0.5 * sum((mag - model)**2 * w - log(w))`` with ``w = 1 / mag_err**2``
    (the constant 2*pi term is omitted). ``chi_ref`` is subtracted so results
    are relative to a reference model.
    """
    inverse_variance = 1.0 / (mag_err**2)
    residual = mag - get_lc(time, modelx, modely, t_zero, weighted_mean)
    chi_terms = residual**2 * inverse_variance - np.log(inverse_variance)
    return -0.5 * np.sum(chi_terms) - chi_ref


def transit_duration(period, radius_1, k, b, incl):
    """Return the total transit duration, in the same units as ``period``.

    Evaluates ``P/pi * arcsin(radius_1 * sqrt((1 + k)**2 - b**2) / sin(incl))``,
    where ``radius_1`` is the scaled primary radius, ``k`` the radius ratio,
    ``b`` the impact parameter and ``incl`` the inclination in radians.
    """
    chord = np.sqrt((1 + k)**2 - b**2)
    angle = np.arcsin(radius_1 * chord / np.sin(incl))
    return period * angle / np.pi



def log_prob(theta, time, mag, mag_err, model_period, model=False):
    """Log-posterior (uniform box priors) for a single-transit fit, used by emcee.

    Parameters
    ----------
    theta : sequence of 5 floats
        ``(t_zero, radius_1, k, b, zp)``.
    time, mag, mag_err : ndarray
        Data; ``time`` is assumed sorted, since the epoch prior uses
        ``time[0]``/``time[-1]``.
    model_period : float
        Orbital period passed to ``lc``.
    model : bool, optional
        If True, return the model magnitudes ``zp - 2.5*log10(flux)`` at
        ``time`` instead of the log-likelihood.

    Returns
    -------
    float or ndarray
        ``-np.inf`` when ``theta`` is outside the priors or non-finite.
        Otherwise the value returned by ``lc(time, mag, mag_err, J=0, ...)``,
        or the model light curve when ``model=True``.
    """
    t_zero, radius_1, k, b, zp = theta

    # NaN compares False against every bound below, so a non-finite walker
    # would otherwise pass all the priors; reject it explicitly.
    if not np.all(np.isfinite(theta)): return -np.inf
    if (t_zero < time[0]) or (t_zero > time[-1]) : return -np.inf
    if (radius_1 < 0.) or (radius_1 > 0.8) : return -np.inf
    if (k < 0.) or (k > 0.8) : return -np.inf
    if (b < 0) or (b > 1 + k) : return -np.inf
    if (zp < -0.1) or (zp > 0.1) : return -np.inf

    # cos(i) = b * R1/a; `lc` is called with the inclination in degrees.
    incl = 180*np.arccos(b*radius_1)/np.pi
    if model:
        return zp - 2.5*np.log10(lc(time, t_zero=t_zero, period=model_period, radius_1=radius_1, k=k, incl=incl, zp=zp))
    return lc(time, mag, mag_err, J=0, t_zero=t_zero, period=model_period, radius_1=radius_1, k=k, incl=incl, zp=zp)





def _load_intervals(path):
    """Load a ``start end`` interval file (one interval per row) as a 2-D array.

    ``ndmin=2`` keeps a one-line file two-dimensional. Without it,
    ``np.loadtxt`` returns a flat array, iteration yields scalars, and the
    ``interval[0]``/``interval[1]`` lookups break.
    """
    return np.loadtxt(path, ndmin=2)


def _interval_mask(times, intervals):
    """Return a boolean array that is True where ``times`` lies strictly inside any interval.

    Only the first two columns (start, end) of each interval row are used.
    """
    inside = np.zeros(times.shape[0], dtype=bool)
    for interval in intervals:
        inside |= (times > interval[0]) & (times < interval[1])
    return inside


def _append_result(results_path, lockfile, header, line, poll=1.0):
    """Append ``line`` to ``results_path`` while holding a lock file.

    The lock is created with ``O_CREAT | O_EXCL``, so testing for and creating
    it is one atomic step; a separate exists-then-create check lets two
    processes both take the lock. ``header`` is written first if the results
    file does not exist yet. The lock is always removed, even if writing fails.
    """
    while True:
        try:
            fd = os.open(lockfile, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
        except FileExistsError:
            time_pack.sleep(poll)
            continue
        os.close(fd)
        try:
            write_header = not os.path.isfile(results_path)
            with open(results_path, 'a') as fh:
                if write_header:
                    fh.write(header)
                fh.write(line)
        finally:
            os.remove(lockfile)
        return


if __name__=='__main__':
    # Parse the arguments
    args = parser.parse_args()
    savepath = args.saveplace
    lockfile = args.saveplace + '.lock'
    results_path = savepath + 'monofind_results.dat'
    # The CLI uses the *string* 'None' for "no mask". Compare by value:
    # `is not 'None'` only worked when both strings happened to be interned.
    use_mask = args.mask != 'None'

    # Load the data once. Columns 0-2 are time, mag, mag_err; any further
    # columns (e.g. flux, flux_err) are ignored. Drop non-finite magnitudes.
    data = np.loadtxt(args.filename, ndmin=2)
    data = data[np.isfinite(data[:, 1])]
    time, mag, mag_err = data[:, 0], data[:, 1], data[:, 2]

    # Sort the time axis
    time, mag, mag_err = lc_sort(time, mag, mag_err)
    # The search uses uniform 1e-3 uncertainties, not the errors from the file.
    mag_err = np.ones(time.shape[0])*1e-3

    # Bad-data windows (empty when no mask file is given); loaded once and reused.
    mask_array = _load_intervals(args.mask) if use_mask else np.empty((0, 2))

    # Mask bad data
    if use_mask:
        bad = _interval_mask(time, mask_array)
        time = time[~bad]
        mag = mag[~bad]
        mag_err = mag_err[~bad]

    # if flatten, flatten
    if args.flatten:
        flux = 10**(-0.4*mag)
        s = TessLightCurve(time, flux, 3000*1e-6*np.ones(time.shape[0]))
        # NOTE: --window_length is not currently passed to flatten().
        s = s.flatten()
        time, mag = s.time, -2.5*np.log10(s.flux)

    if args.remove_dropouts:
        # Replace isolated points that differ from BOTH neighbours by more than
        # the threshold with zero-mean noise.
        drop_out_thresh = 0.05
        for i in range(1, time.shape[0]-1):
            d1 = abs(mag[i] - mag[i-1])
            d2 = abs(mag[i] - mag[i+1])
            if (d1 > drop_out_thresh) and (d2 > drop_out_thresh):
                mag[i] = np.random.normal(0, np.std(mag))

    # Reference level (the median, despite the variable name)
    weighted_mean = np.median(mag)

    # Resample onto a regular grid padded by half the longest search window
    time, mag, mag_err = lc_resample(time, mag, mag_err, 16/24, weighted_mean)

    # Resampling interpolates across the masked windows, so overwrite those
    # samples with zero-mean noise to avoid spurious detections there.
    if use_mask:
        bad = _interval_mask(time, mask_array)
        mag[bad] = np.random.normal(0, np.std(mag), int(np.count_nonzero(bad)))

    # Log-likelihood of a flat model at weighted_mean (the reference level)
    chi_ref = get_lc_loglike(time, mag, mag_err, time, np.ones(time.shape[0])*weighted_mean, 0, weighted_mean, 0.)

    if args.trial:
        f = plt.figure(figsize=(15,5))
        plt.scatter(time, mag, c='k', s=10)
        plt.axhline(weighted_mean)
        plt.title('Loglike ref = ' + str(chi_ref))
        plt.gca().invert_yaxis()
        plt.xlabel('Time')
        plt.ylabel('Mag')
        plt.show()
        sys.exit()

    # Fixed period used to build the transit-shape models
    model_period = 10.

    # Search for transit widths between 1 and 16 hours
    transit_durations = np.arange(1/24, 16/24 + 0.5/24, 0.5/24)
    radius_1s = np.pi*transit_durations/model_period

    # Radius ratios from sqrt(2*std(mag)) up to 0.5
    ks = np.linspace(np.sqrt(np.std(mag)*2), 0.5, 30)

    # Now get the transit durations
    Tdur = np.zeros((radius_1s.shape[0], ks.shape[0]))
    for i in range(radius_1s.shape[0]):
        for j in range(ks.shape[0]):
            Tdur[i,j] = transit_duration(model_period, radius_1s[i], ks[j], 0, np.pi/2)

    # Tabulated transit models centred on t=0: Models[i, j] = (times, mags)
    Npoints_in_transit_model = 100
    Models = np.zeros((radius_1s.shape[0], ks.shape[0], 2, Npoints_in_transit_model))
    for i in range(radius_1s.shape[0]):
        for j in range(ks.shape[0]):
            Models[i,j,0] = np.linspace(-Tdur[i,j]/2, Tdur[i,j]/2, Npoints_in_transit_model)
            Models[i,j,1] = weighted_mean-2.5*np.log10(lc(Models[i,j,0], period=model_period, radius_1=radius_1s[i], k=ks[j]))

    # Likelihood (relative to chi_ref) of every model at every epoch
    Cube = np.zeros((radius_1s.shape[0], ks.shape[0], time.shape[0]))
    for i in range(radius_1s.shape[0]):
        for j in range(ks.shape[0]):
            for k in range(time.shape[0]):
                Cube[i,j,k] = get_lc_loglike(time, mag, mag_err, Models[i,j,0], Models[i,j,1], time[k], weighted_mean, chi_ref)

    # Best (radius_1, k) cell: the one whose maximum over epoch is largest
    best_map = np.max(Cube, axis=2)
    best = np.unravel_index(best_map.argmax(), best_map.shape)

    median = np.median(Cube[best])
    Cube = (Cube - median)

    # Peak threshold: 10x the scatter of the central 80% of the best trace
    best_trace = Cube[best]
    central = best_trace[(best_trace > np.percentile(best_trace, 10)) & (best_trace < np.percentile(best_trace, 90))]
    height = 10*np.std(central)
    peaks, _ = find_peaks(best_trace, height=height, distance=24)

    best_fitting_pars = []
    print(len(peaks))
    if len(peaks) == 0:
        sys.exit()

    for i in range(peaks.shape[0]):
        t_peak = time[peaks[i]]
        # Start the fit from the best grid cell at this epoch
        current_best = np.unravel_index(Cube[:,:,peaks[i]].argmax(), Cube[:,:,peaks[i]].shape)
        radius_1 = radius_1s[current_best[0]]
        k = ks[current_best[1]]
        width = transit_durations[current_best[0]]
        # t_zero, radius_1, k, b, zp
        theta = [t_peak, radius_1, k, 0.2, 0.0]

        ndim = len(theta)
        nwalkers = 4*ndim
        p0 = theta + 1e-8*np.random.randn(nwalkers, ndim)

        # Fit only the data within +/- 4 grid durations of the peak
        mask = ~((time < (t_peak - 4*width)) | (time > (t_peak + 4*width)))

        # Create the emcee ensemble
        sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args=[time[mask], mag[mask], mag_err[mask], model_period])
        sampler.run_mcmc(p0, 3000, progress=False)
        flat_samples = sampler.get_chain(discard=2000, flat=True)
        flat_logs = sampler.get_log_prob(discard=2000, flat=True)
        best_sample = flat_samples[np.argmax(flat_logs)].tolist()
        incl = 180*np.arccos(best_sample[1]*best_sample[3])/np.pi
        width = transit_duration(model_period, best_sample[1], best_sample[2], best_sample[3], np.pi*incl/180.)
        model = log_prob(best_sample, time, mag, mag_err, model_period, model=True)

        OOT_mask = (time[mask] < (t_peak - width/2)) | (time[mask] > (t_peak + width/2))
        OOT_STD = np.std(mag[mask][OOT_mask])

        # Vet: reject fits shorter than 2.5 h or with a model depth below 6e-3 mag
        if (width < 2.5/24) or ((np.max(model) - np.min(model)) < 6e-3):
            best_fitting_pars.append([])
        else:
            # best_sample becomes [t_zero, radius_1, k, b, zp, incl, OOT_STD, loglike/N]
            best_sample.append(incl)
            best_sample.append(OOT_STD)
            best_sample.append(np.max(flat_logs)/len(time[mask]))
            best_fitting_pars.append(best_sample)

    number_of_real_peaks = sum(1 for pars in best_fitting_pars if len(pars) > 0)
    if number_of_real_peaks == 0:
        sys.exit()

    fig2, axs = plt.subplots(nrows=number_of_real_peaks + 1, ncols=1, figsize=(10, number_of_real_peaks*5),)

    # Top panel: likelihood trace of the best grid cell with the detected peaks
    axs[0].plot(time, Cube[best], 'k')
    axs[0].plot(time[peaks], Cube[best][peaks], "x")
    axs[0].set_xlabel('Time')
    axs[0].set_ylabel(r'$\mathcal{L} - \mathcal{L}_{\rm wm}$')
    axs[0].set_title('Number of peaks: {:} [{:} vetted]'.format(len(peaks), number_of_real_peaks))
    axs[0].axhline(height, c='b', ls='--')
    axs[0].axhline(0, ls='--', color='k')
    axs[0].grid()
    for i in range(peaks.shape[0]):
        axs[0].text(time[peaks][i]+0.25, Cube[best][peaks][i], '{:}'.format(i+1), fontsize=15)
    for interval in mask_array:
        y0, y1 = axs[0].get_ylim()
        axs[0].fill_between(interval[:2], [y0, y0], [y1, y1], facecolor='green', alpha=0.3)

    # NOTE(review): the "SNR" column actually receives OOT_STD (best_sample[-2]);
    # the header is kept unchanged so existing results files stay consistent.
    header = 'filename,peak, t_cen, radius_1, k, b, width, depth, SNR, loglike_peak\n'

    ax_count = 1
    peak_count = 1
    for best_sample in best_fitting_pars:
        if len(best_sample) == 0:
            peak_count += 1
            continue

        ax = axs[ax_count]
        ax.set_ylabel('Mag [mmag]')
        width = transit_duration(model_period, best_sample[1], best_sample[2], best_sample[3], np.pi*best_sample[5]/180.)
        mask = ~((time < (best_sample[0] - 4*width)) | (time > (best_sample[0] + 4*width)))
        model_time = np.linspace(time[mask][0], time[mask][-1], 1000)
        model = best_sample[4] - 2.5*np.log10(lc(model_time, t_zero=best_sample[0], period=model_period, radius_1=best_sample[1], k=best_sample[2], incl=best_sample[5]))
        ax.scatter(time[mask], mag[mask]*1e3, c='k', s=10)
        ax.plot(model_time, model*1e3, 'r')
        ax.axhline(6, ls='--', c='k', alpha=0.5)
        # OOT_STD is stored in mag while this axis is in mmag, so convert it.
        oot_std_mmag = best_sample[-2]*1e3
        ax.fill_between(time[mask], oot_std_mmag*np.ones(time[mask].shape[0]), -oot_std_mmag*np.ones(time[mask].shape[0]), facecolor='red', alpha=0.25)

        depth = np.max(model)*1e3 - np.min(model)*1e3
        ax.set_ylim(2*depth, -1*depth)
        # Shade the masked windows only once the y-limits are final, so the
        # shading spans the visible panel.
        for interval in mask_array:
            y0, y1 = ax.get_ylim()
            ax.fill_between(interval[:2], [y0, y0], [y1, y1], facecolor='green', alpha=0.3)
        # Use this peak's own normalised log-likelihood (best_sample[-1]), not the
        # flat_logs left over from whichever MCMC run happened to finish last.
        ax.set_title('Loglike_reduce : {:.2f} [PEAK {:}]'.format(best_sample[-1], peak_count))
        ax.set_xlim(time[mask][0], time[mask][-1])
        ax.grid()
        ax_count += 1
        peak_count += 1

        # Now save to the text file
        line = '{:}, {:}, {:}, {:}, {:}, {:}, {:}, {:}, {:}, {:}\n'.format(args.filename, peak_count-1, best_sample[0], best_sample[1], best_sample[2], best_sample[3], width, depth, best_sample[-2], best_sample[-1])
        _append_result(results_path, lockfile, header, line)

    ax.set_xlabel('Time [BTJD]')
    fig2.align_ylabels(axs)
    fig2.savefig(savepath+os.path.basename(os.path.splitext(args.filename)[0])+'_transits_found.png')
    plt.close(fig2)