#!python

import numpy as np 
import matplotlib.pyplot as plt 
from bruce import lc 
from tqdm import tqdm
import emcee, corner, sys, os, numpy as np
import argparse 
from multiprocessing import Pool
from scipy.stats import chisquare, sem
import matplotlib.cm as cm
from scipy.signal import find_peaks
# Silence every warning for the whole script (including numpy/scipy runtime
# warnings). `np.warnings` (the stdlib module re-exported by older NumPy) no
# longer exists in NumPy >= 1.24, so use the stdlib module directly.
import warnings
warnings.filterwarnings('ignore')
from astropy.stats import LombScargle

# LC bin 
def lc_bin(time, flux, bin_width):
    '''
    Bin a lightcurve into contiguous, equal-width bins in time.

    Bins start at min(time) and extend far enough to include max(time), so
    no trailing data are discarded. Bins holding fewer than two points (no
    standard error can be estimated), or whose standard error is NaN (e.g.
    because the bin contains a NaN flux), are omitted from the output.

    Parameters
    ----------
    time : array_like
        Sample times.
    flux : array_like
        Values to bin (flux or magnitude), same length as ``time``.
    bin_width : float
        Width of each bin, in the same units as ``time``.

    Returns
    -------
    time_bin, flux_bin, err_bin : ndarray
        Bin centres, the mean of ``flux`` in each retained bin, and its
        standard error of the mean (``scipy.stats.sem``), ordered in time.

    Raises
    ------
    ValueError
        If ``bin_width`` is not a positive number.
    '''
    if not bin_width > 0:
        raise ValueError('bin_width must be positive, got {!r}'.format(bin_width))

    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    t_min = np.min(time)
    # Enough bins that the last edge reaches max(time); at least one bin even
    # when every sample shares the same time.
    n_bins = max(int(np.ceil((np.max(time) - t_min) / bin_width)), 1)
    edges = t_min + bin_width * np.arange(n_bins + 1)
    centres = 0.5 * (edges[1:] + edges[:-1])

    # digitize returns i with edges[i-1] <= t < edges[i]; clipping puts a point
    # lying exactly on (or, through rounding, just past) the final edge into
    # the last bin instead of losing it.
    dig = np.clip(np.digitize(time, edges), 1, n_bins)

    # Group samples by bin in O(n log n): a stable sort keeps each bin's
    # samples in their original order, then split wherever the bin index changes.
    order = np.argsort(dig, kind='stable')
    groups = np.split(flux[order], np.flatnonzero(np.diff(dig[order])) + 1)

    time_bin, flux_bin, err_bin = [], [], []
    # np.unique yields the sorted bin indices, one per group, in group order.
    for label, members in zip(np.unique(dig), groups):
        if members.size < 2:
            continue
        err = sem(members)
        if np.isnan(err):
            continue
        time_bin.append(centres[label - 1])
        flux_bin.append(members.mean())
        err_bin.append(err)

    return np.array(time_bin), np.array(flux_bin), np.array(err_bin)




# Welcome message
welcome_message = '''---------------------------------------------------
-                   tls V.1                   -
-             samuel.gill@warwick.ac.uk           -
---------------------------------------------------'''

description = '''A program to search for transit events in ground-based photometry.'''

# Command-line interface. Each help string echoes the real default through
# argparse's %(default)s specifier so the text cannot drift from the code.
parser = argparse.ArgumentParser('tls', description=description)

parser.add_argument("filename",
                    help='The filename of the ground-based photometry.')

parser.add_argument('-a',
                    '--bin',
                    help='The bin width used to bin the lightcurve, in minutes; '
                         '0 disables binning [default=%(default)s].',
                    default=0.0, type=float)

parser.add_argument('-b',
                    '--period',
                    help='The orbital period, in the same time units as the input file '
                         '[default=%(default)s].',
                    default=4.511, type=float)

parser.add_argument('-c',
                    '--radius_1',
                    help='The radius of star 1 in units of the semi-major axis, a '
                         '[default=%(default)s].',
                    default=0.13, type=float)

parser.add_argument('-d',
                    '--k',
                    help='The ratio of the radii of star 2 and star 1 (R2/R1) '
                         '[default=%(default)s].',
                    default=0.09619, type=float)

parser.add_argument('-e',
                    '--b',
                    help='The impact parameter of the orbit, where incl = arccos(radius_1*b) '
                         '[default=%(default)s].',
                    default=0., type=float)

# NOTE(review): the search in __main__ does not currently read args.threshold.
parser.add_argument('-f',
                    '--threshold',
                    help='The difference in chi-squared minimum for a peak '
                         '[default=%(default)s].',
                    default=860, type=float)

parser.add_argument('-g',
                    '--distance',
                    help='The minimum index separation between peaks [default=%(default)s].',
                    default=100, type=int)

parser.add_argument('-i',
                    '--light_3',
                    help='The third light [default=%(default)s].',
                    default=0.0, type=float)

if __name__ == '__main__':
    # Template-matching transit search: slide a model transit across the
    # lightcurve, score every trial mid-time by the drop in chi-squared
    # relative to a flat line, and report the strongest local maxima.
    # All output files are written to the current working directory.
    print(welcome_message)

    # Parse arguments
    args = parser.parse_args()

    # Output files are named after the input file, without directory or extension.
    prefix = os.path.basename(os.path.splitext(args.filename)[0])

    # Load the lightcurve once (columns: time, mag, mag_err). ndmin=2 keeps a
    # one-row file two-dimensional so the row filtering below still applies.
    data = np.loadtxt(args.filename, ndmin=2)

    # Drop every row with a NaN in any column; a NaN time or error would
    # otherwise break the sort, the interpolation and the chi-squared sums.
    data = data[~np.isnan(data).any(axis=1)]

    # Sort chronologically; np.interp requires increasing sample times.
    data = data[np.argsort(data[:, 0], kind='stable')]
    time, mag, mag_err = data.T
    print('Successfully read {:} lines from {:}'.format(len(time), args.filename))

    if args.bin > 0:
        # --bin is given in minutes; the script treats time as days.
        time, mag, mag_err = lc_bin(time, mag, args.bin / 24. / 60.)
        print('\treduced to {:} lines with {:}-minute binning'.format(len(time), args.bin))

    if len(time) < 2:
        sys.exit('Need at least two valid data points, found {:}.'.format(len(time)))

    # Reference (out-of-transit) level: the median magnitude, which is robust
    # to outliers such as in-transit points.
    ref_mag = np.median(mag)

    # Chi-squared per data point of a flat model at the reference level.
    inv_var = 1.0 / mag_err**2
    Chi_WM = np.sum((mag - ref_mag)**2 * inv_var) / len(time)
    print('Median of photometry = {:.6f} mag.'.format(ref_mag))
    print('\tChi_WM = ', Chi_WM)

    fig = plt.figure(figsize=(15, 4))
    plt.scatter(time, mag, c='k', s=5)
    plt.plot(time, np.full_like(time, ref_mag))
    plt.gca().invert_yaxis()
    plt.xlabel('Time [d]')
    plt.ylabel('Mag')
    fig.tight_layout()
    fig.savefig(prefix + 'photometry_weighted_mean.png')
    plt.close(fig)

    # Model transit on a +/-0.3 (time units) grid centred on mid-transit, from
    # bruce's lc() with fixed limb-darkening coefficients.
    # NOTE(review): lc() is assumed to return normalised flux, hence the
    # conversion to magnitudes about the reference level.
    time_template = np.linspace(-0.3, 0.3, 10000)
    incl = 180 * np.arccos(args.b * args.radius_1) / np.pi
    mag_template = ref_mag - 2.5 * np.log10(lc(time_template, period=args.period, radius_1=args.radius_1,
                                                k=args.k, incl=incl, ldc_1_1=0.8, ldc_1_2=0.8,
                                                ld_law_1=2, light_3=args.light_3))

    # In-transit part of the template: more than 0.5 mmag fainter than the
    # reference level. half_width is the (positive) time from mid-transit to
    # the earliest such point, so t0 - half_width < t0 + half_width below.
    width_mask = mag_template > ref_mag + 0.0005
    if not width_mask.any():
        sys.exit('Template transit is shallower than 0.5 mmag; cannot measure its width.')
    half_width = abs(np.min(time_template[width_mask]))
    depth_template = np.max(mag_template) - np.min(mag_template)
    print('Transit width = {:.2f} hrs'.format(2 * half_width * 24))

    fig = plt.figure(figsize=(15, 4))
    plt.plot(time_template, mag_template, 'k')
    plt.axvline(half_width, c='k', ls='--')
    plt.axvline(-half_width, c='k', ls='--')
    plt.gca().invert_yaxis()
    plt.xlabel('Time [d]')
    plt.ylabel('Mag')
    fig.tight_layout()
    fig.savefig(prefix + 'template.png')
    plt.close(fig)

    # Trial mid-transit times: a uniform grid spanning the observations.
    time_resampled = np.linspace(time[0], time[-1], 100000)

    # Search statistic: improvement in chi-squared per point when the template
    # is centred on each trial time (positive = better than a flat line).
    # Outside the template window the model equals the reference level.
    n_points = len(time)
    dM = np.empty(time_resampled.shape[0])
    for i, t0 in enumerate(tqdm(time_resampled)):
        model = np.interp(time, time_template + t0, mag_template, left=ref_mag, right=ref_mag)
        dM[i] = Chi_WM - np.sum((mag - model)**2 * inv_var) / n_points

    fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(15, 4), sharex=True)
    ax1.scatter(time, mag, c='k', s=10)
    ax1.invert_yaxis()
    ax1.set_ylabel('Mag')
    ax2.plot(time_resampled, dM)
    ax2.set_xlabel('Time [d]')
    ax2.set_ylabel(r'$\chi_{WM}^2 - \chi^2$')

    # Candidates: local maxima of dM above 1, at least --distance grid samples
    # apart. NOTE(review): --threshold is not applied here.
    peaks, _ = find_peaks(dM, distance=args.distance, height=1)
    print('Number of peaks : {:}'.format(len(peaks)))
    ax2.plot(time_resampled[peaks], dM[peaks], "x")
    ax2.set_ylim(0, None)
    fig.tight_layout()
    fig.savefig(prefix + 'chi_metric_peaks.png')
    plt.close(fig)

    peak_times = time_resampled[peaks]

    # 30-minute binned lightcurve, overplotted on each epoch for readability.
    tb, mb, mbe = lc_bin(time, mag, 30. / 24. / 60.)

    # With more than four candidates, also build a summary grid, four per row.
    make_master = len(peaks) > 4
    if make_master:
        n_rows = int(np.ceil(len(peaks) / 4))
        fig_master, axs = plt.subplots(nrows=n_rows, ncols=4, figsize=(10, 10 * n_rows), squeeze=False)

    y_limits = (ref_mag - 2 * depth_template, ref_mag + 2 * depth_template)

    with open(prefix + 'tls_results.dat', 'w') as f:
        for i, t0 in enumerate(peak_times):
            x_limits = (t0 - 3 * half_width, t0 + 3 * half_width)
            # Separation from the previous candidate (0 for the first one).
            diff = t0 - peak_times[max(i - 1, 0)]

            fig = plt.figure(figsize=(15, 4))
            plt.scatter(time, mag, c='k', s=10)
            plt.scatter(tb, mb, c='r', s=10)
            plt.xlim(*x_limits)
            plt.plot(time_template + t0, mag_template, 'r')
            plt.ylim(*y_limits)
            plt.title('Epoch {:}'.format(i + 1))
            plt.gca().invert_yaxis()
            fig.savefig(prefix + 'TLS_EPOCH_{:}.png'.format(i + 1))
            plt.close(fig)

            print('Epoch {:} : {:} (diff {:})'.format(i + 1, t0, diff))
            f.write('{:} {:} {:}\n'.format(i + 1, t0, diff))

            if make_master:
                ax = axs[i // 4, i % 4]
                ax.scatter(time, mag, c='k', s=10)
                ax.set_xlim(*x_limits)
                ax.plot(time_template + t0, mag_template, 'r')
                ax.set_title('Night {:}'.format(i + 1))
                ax.set_ylim(*y_limits)
                ax.invert_yaxis()
                ax.get_xaxis().get_major_formatter().set_useOffset(False)

    if make_master:
        # Lay out before saving so the saved image uses the tightened layout.
        fig_master.tight_layout()
        fig_master.savefig(prefix + '_Transits_found.png')
        plt.close(fig_master)
