#!python
import numpy as np 
import matplotlib.pyplot as plt 
from bruce import lc 
from tqdm import tqdm
import emcee, corner, sys, os, numpy as np
import argparse 
from multiprocessing import Pool
from scipy.stats import chisquare, sem
import matplotlib.cm as cm
from scipy.signal import find_peaks
np.warnings.filterwarnings('ignore')
from astropy.stats import LombScargle

import numba, numba.cuda
from interpolation import interp


# LC bin 
def lc_bin(time, flux, bin_width):
    '''
    Bin a lightcurve into contiguous bins of a fixed width.

    Bins start at min(time). Every point, including those in the final partial
    bin, is assigned to exactly one bin. Bins whose standard error cannot be
    computed are dropped. These are bins whose scipy ``sem`` (default ddof=1) is
    NaN, e.g. single-point bins or bins containing non-finite flux.

    Parameters
    ----------
    time : array_like
        Observation times (assumed finite).
    flux : array_like
        Flux (or magnitude) values, same length as ``time``.
    bin_width : float
        Width of each bin; must be positive and in the same units as ``time``.

    Returns
    -------
    time_bin, flux_bin, err_bin : ndarray
        Bin centres, mean flux per bin and its standard error of the mean,
        for the retained bins in ascending time order.

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive.
    '''
    if not bin_width > 0:
        raise ValueError('bin_width must be positive, got {!r}'.format(bin_width))
    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    if time.size == 0:
        empty = np.array([], dtype=float)
        return empty, empty.copy(), empty.copy()

    t0 = np.min(time)
    # Direct bin index per point. Edges built with np.arange(min, max, width)
    # stop short of max(time) and silently discard the last partial bin.
    bin_index = np.floor((time - t0) / bin_width).astype(int)

    # Group points by bin in one pass. The stable sort keeps each bin's points
    # in input order, so the per-bin mean is summed in the same order as before.
    order = np.argsort(bin_index, kind='stable')
    occupied, first = np.unique(bin_index[order], return_index=True)
    groups = np.split(flux[order], first[1:])

    time_binned = t0 + (occupied + 0.5) * bin_width
    flux_binned = np.array([g.mean() for g in groups])
    err_binned = np.array([sem(g) for g in groups])

    keep = ~np.isnan(err_binned)
    return time_binned[keep], flux_binned[keep], err_binned[keep]


@numba.njit
def _fit_single_period(time, phase, mag, period, phase_steps, time_template,
                       mag_template, x_bin, x_bin_edges, y_bin, y_bin_count,
                       best_chi, best_offset):
    """
    Fold, bin and template-fit the lightcurve at one trial period.

    Compiled in nopython mode so the whole per-period workload runs natively.
    ``phase``, ``y_bin`` and ``y_bin_count`` are scratch buffers that are
    overwritten in place. ``best_chi`` / ``best_offset`` are the incoming best
    values. They are returned unchanged unless some phase step gives a larger
    statistic.

    Returns
    -------
    (best_chi, best_offset) : tuple of float
    """
    # Phase-fold every observation at the trial period.
    for j in range(time.shape[0]):
        phase[j] = (time[j] / period) % 1

    # Reset the bin accumulators before re-binning at this period.
    for k in range(y_bin.shape[0]):
        y_bin[k] = 0.
        y_bin_count[k] = 0

    numba_bin(phase, mag, x_bin_edges, y_bin, y_bin_count)

    # Slide the template across the phase grid and keep the best match.
    # NOTE(review): bins with no data keep y_bin == 0 and still enter the sum,
    # which presumably assumes mag is normalised to ~0 out of transit -- confirm.
    for j in range(phase_steps.shape[0]):
        model = interp(time_template + phase_steps[j], mag_template, x_bin)
        chi = -0.5 * ((y_bin - model) ** 2).sum()
        if chi > best_chi:
            best_chi = chi
            best_offset = phase_steps[j]
    return best_chi, best_offset


def main_function(time, phase, mag, mag_err, weights, periods, phase_steps, time_template, mag_template, SR, phase_offset, x_bin, x_bin_edges,y_bin, y_bin_count):
    """
    Grid-search the trial periods, filling ``SR`` and ``phase_offset`` in place.

    For each ``periods[i]`` the lightcurve is phase-folded and binned onto
    ``x_bin_edges``. It is then compared with the template shifted by every
    value in ``phase_steps``. ``SR[i]`` is raised to the largest
    ``-0.5 * sum((y_bin - model)**2)`` that exceeds its incoming value (the
    script initialises it to -inf). ``phase_offset[i]`` receives the
    corresponding phase step.

    ``mag_err`` and ``weights`` are accepted for interface compatibility but
    are not used by the fit. On return, ``phase``, ``y_bin`` and
    ``y_bin_count`` hold the scratch values for the last period.
    """
    # Only the per-period kernel is compiled. printProgressBar is plain Python,
    # so jitting the whole loop cannot compile in nopython mode (the @jit
    # default since Numba 0.59) and fell back to slow object mode before that.
    n_periods = periods.shape[0]
    for i in range(n_periods):
        SR[i], phase_offset[i] = _fit_single_period(
            time, phase, mag, periods[i], phase_steps, time_template,
            mag_template, x_bin, x_bin_edges, y_bin, y_bin_count,
            SR[i], phase_offset[i])
        printProgressBar(i + 1, n_periods, prefix = 'Progress:', suffix = 'Complete', length = 50)

@numba.njit
def numba_bin(x,y,x_bin_edges, y_bin, y_bin_count):
    """
    Accumulate the mean of ``y`` in bins of ``x``, writing into the outputs in place.

    Bin ``j`` covers ``x_bin_edges[j] < x <= x_bin_edges[j+1]``. The first bin
    also includes its left edge, so a point exactly on ``x_bin_edges[0]``
    (e.g. phase 0.0) is counted instead of dropped. Points outside the edges,
    and NaN points, are ignored.

    Parameters
    ----------
    x, y : 1-D arrays of equal length.
    x_bin_edges : 1-D array of bin edges, assumed sorted ascending (required
        by the binary search).
    y_bin : per-bin sums are added here and then divided by the counts. The
        caller zeroes it first, so it ends up holding per-bin means (0 for
        empty bins).
    y_bin_count : per-bin point counts are added here.
    """
    n_bins = x_bin_edges.shape[0] - 1
    if n_bins < 1:
        return
    for i in range(x.shape[0]):
        # searchsorted (default side='left') returns k with
        # edges[k-1] < x <= edges[k]. So j = k - 1 keeps the (left, right]
        # convention, found in O(log n_bins) instead of a scan over every bin.
        j = np.searchsorted(x_bin_edges, x[i]) - 1
        if j < 0 and x[i] == x_bin_edges[0]:
            j = 0  # include the left edge of the first bin
        if j >= 0 and j < n_bins:
            y_bin[j] += y[i]
            y_bin_count[j] += 1
    for j in range(n_bins):
        if y_bin_count[j] != 0:
            y_bin[j] /= y_bin_count[j]  # sums -> means


# Print iterations progress
def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'):
    """
    Draw (or redraw) a single-line text progress bar on stdout.

    Call once per loop iteration. Each call overwrites the previous bar via a
    carriage return, and a newline is printed when ``iteration == total``.

    @params:
        iteration   - Required  : current iteration (Int)
        total       - Required  : total iterations (Int); 0 is drawn as complete
        prefix      - Optional  : prefix string (Str)
        suffix      - Optional  : suffix string (Str)
        decimals    - Optional  : positive number of decimals in percent complete (Int)
        length      - Optional  : character length of bar (Int)
        fill        - Optional  : bar fill character (Str)
    """
    if total:
        fraction = iteration / float(total)
        filledLength = int(length * iteration // total)
    else:
        # Nothing to iterate over: show a full bar instead of dividing by zero.
        fraction = 1.0
        filledLength = length
    percent = ("{0:." + str(decimals) + "f}").format(100 * fraction)
    bar = fill * filledLength + '-' * (length - filledLength)
    # flush=True: the bar never ends in a newline, so on line-buffered stdout
    # it would otherwise not appear until the loop completes.
    print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r', flush = True)
    # Print New Line on Complete
    if iteration == total:
        print()



@numba.cuda.jit
def main_function_GPU(time, phase, mag, mag_err, weights, periods, phase_steps, time_template, mag_template, SR, phase_offset, x_bin, x_bin_edges,y_bin, y_bin_count):
    """
    Experimental CUDA counterpart of ``main_function``: one thread per trial period.

    Thread ``i = cuda.grid(1)`` folds, bins and template-fits the lightcurve
    at ``periods[i]``, as in the CPU loop body. It writes its best statistic
    and offset to ``SR[i]`` / ``phase_offset[i]``.

    NOTE(review): not usable as written (its launch in __main__ is commented out):
      * ``phase``, ``y_bin`` and ``y_bin_count`` are single shared buffers
        indexed by j/k, not by ``i``. Every thread therefore overwrites the
        same elements, which is a data race. Each thread needs its own scratch
        row (e.g. 2-D buffers indexed by ``i``).
      * ``time_template + phase_steps[j]`` and ``y_bin - model`` build new
        arrays. The Numba CUDA target presumably cannot allocate arrays inside
        a kernel -- confirm.
      * ``numba_bin`` (an ``@numba.njit`` function) and ``interp`` must be
        callable from device code -- confirm against the installed Numba.
    """
    i = numba.cuda.grid(1)
    # The launch grid is rounded up to whole blocks, so surplus threads must
    # exit before indexing periods / SR / phase_offset out of bounds.
    if i >= periods.shape[0]:
        return

    # Phase-fold at this thread's trial period.
    for j in range(time.shape[0]):
        phase[j] = (time[j] / periods[i] ) %1

    # Reset the bin accumulators, then bin for speed.
    for k in range(y_bin.shape[0]):
        y_bin[k] = 0.
        y_bin_count[k] = 0

    numba_bin(phase,mag,x_bin_edges, y_bin, y_bin_count)

    # Slide the template across the phase grid and keep the best match.
    for j in range(phase_steps.shape[0]):

        model = interp(time_template + phase_steps[j], mag_template, x_bin)

        chi = -0.5*((y_bin - model)**2).sum()

        if chi > SR[i]:
            SR[i] = chi
            phase_offset[i] = phase_steps[j]




# Welcome message
welcome_message = '''---------------------------------------------------
-                   M-BLS V.1                     -
-             samuel.gill@warwick.ac.uk           -
---------------------------------------------------'''

description = '''A program to search for transit events in ground-based photometry using a modified BLS algorithm.'''


# Command-line interface. Every option except the filename has a default.
parser = argparse.ArgumentParser('tls', description=description)

parser.add_argument("filename",
                    help='The filename of the ground-based photometry.')

parser.add_argument('-a',
                    '--bin',
                    help='The bin width with which to bin the lightcurve, in minutes; 0 disables binning [default=0].',
                    default=0.0, type=float)

parser.add_argument('-b',
                    '--period',
                    help='The orbital period in arbitrary time units consistent with the input file (used to report the transit duration).',
                    default=4.511, type=float)

parser.add_argument('-c',
                    '--radius_1',
                    help='The radius of star 1 in units of the semi-major axis, a.',
                    default=0.13, type=float)

parser.add_argument('-d',
                    '--k',
                    help='The ratio of the radii of star 2 and star 1 (R2/R1).',
                    default=0.09619, type=float)

parser.add_argument('-e',
                    '--b',
                    help='The impact parameter of the orbit (incl = arccos(radius_1*b)).',
                    default=0., type=float)

if __name__=='__main__':
    # print welcome message
    print(welcome_message)

    # Parse arguments
    args = parser.parse_args()

    # Load the lightcurve. Columns 1-3 must be time, mag and mag_err; any
    # further columns are ignored. This replaces a bare ``except:`` that only
    # allowed 3 or 5 columns and hid unrelated errors. ndmin=2 keeps a one-row
    # file two-dimensional so the column slicing below still works.
    data = np.loadtxt(args.filename, ndmin=2)
    if data.shape[1] < 3:
        sys.exit('{:} needs at least 3 columns (time, mag, mag_err), found {:}'.format(args.filename, data.shape[1]))
    time, mag, mag_err = data[:, 0], data[:, 1], data[:, 2]
    print('Successfully read {:} lines from {:}'.format(len(time), args.filename))

    # A single NaN magnitude makes its phase bin NaN, hence chi NaN. The
    # ``chi > SR[i]`` test in the search would then never succeed and SR would
    # stay at -inf for every period. So drop non-finite rows up front.
    finite = np.isfinite(time) & np.isfinite(mag) & np.isfinite(mag_err)
    if not np.all(finite):
        print('\tdropped {:} rows containing non-finite values'.format(np.count_nonzero(~finite)))
        time, mag, mag_err = time[finite], mag[finite], mag_err[finite]
    if time.size == 0:
        sys.exit('No usable data in {:}'.format(args.filename))

    if args.bin > 0 :
        time, mag, mag_err = lc_bin(time, mag, args.bin/24./60.)
        print('\treduced to {:} lines with {:}-minute binning'.format(len(time), args.bin))
        if time.size == 0:
            sys.exit('No bins with more than one point at {:}-minute binning'.format(args.bin))

    time_min = np.min(time)
    time = time - time_min


    ###############################
    # Let's calculate the weights
    ###############################
    weights = mag_err**-2*(np.sum(mag_err**-2))**-1   # It's assumed that weight*mag == 0. In reality its < 1e-3 for NGTS-2b

    ###################################
    # Now we need to make the template
    # we need
    #  radius1, k , b -> we're in phase so don't need period
    ###################################
    time_template = np.linspace(-0.1,0.1, 10000)
    mag_template = -2.5*np.log10(lc(time_template, period = 1., radius_1=args.radius_1, k=args.k, incl = 180*np.arccos(args.b*args.radius_1)/np.pi, ldc_1_1=0.8, ldc_1_2 = 0.8, ld_law_1=2 ))
    width_mask = (mag_template > 0.000)
    # With no in-transit template points, np.min below would fail with an
    # opaque "zero-size array" error; report the likely cause instead.
    if not np.any(width_mask):
        sys.exit('The transit template contains no in-transit points; check --radius_1, --k and --b.')
    half_width = -np.min(time_template[width_mask])
    full_width = 2*half_width


    depth_template = np.max(mag_template) - np.min(mag_template)
    print('Transit width = {:.2f} hrs'.format(full_width*args.period*24))

    fig = plt.figure(figsize=(15,4))
    plt.plot(time_template, mag_template, 'k')
    plt.axvline(half_width, c='k', ls='--')
    plt.axvline(-half_width, c='k', ls='--')
    plt.gca().invert_yaxis()
    plt.xlabel('Phase')
    plt.ylabel('Mag')
    plt.xlim(-2*half_width, 2*half_width)
    fig.tight_layout()

    fig.savefig('transit_template.png')
    plt.close(fig)


    # Make the call.
    # NOTE: the trial-period and phase-offset grids are hard-coded here;
    # --period only feeds the transit-duration printout above.
    periods = np.linspace(300, 400,50000)
    phase_steps = np.linspace(0.1,0.9,100)
    SR = np.ones(periods.shape[0])*-np.inf
    phase_offset = np.empty_like(SR)
    phase  = np.empty_like(time)

    Nbins = 151
    x_bin_edges = np.linspace(0,1,Nbins)
    x_bin = (x_bin_edges + (x_bin_edges[1] - x_bin_edges[0])/2)[:-1]
    y_bin = np.zeros(x_bin_edges.shape[0]-1)
    y_bin_count = np.zeros(x_bin_edges.shape[0]-1)


    # CPU
    main_function(time, phase, mag, mag_err, weights, periods, phase_steps, time_template, mag_template, SR, phase_offset, x_bin, x_bin_edges,y_bin, y_bin_count)

    # GPU (experimental). The kernel uses one thread per trial period, so the
    # grid must cover periods.shape[0], not the number of data points.
    #threads_per_block = 256
    #blocks = int(np.ceil(periods.shape[0] / threads_per_block))
    #main_function_GPU[blocks, threads_per_block](time, phase, mag, mag_err, weights, periods, phase_steps, time_template, mag_template, SR, phase_offset, x_bin, x_bin_edges,y_bin, y_bin_count)


    f, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(10,10))
    ax1.plot(periods, SR)
    ax1.set_xlabel('Period [d]')
    ax1.set_ylabel(r'$\mathcal{L}$')


    best_index = np.argmax(SR)
    best_period = periods[best_index]
    best_offset = phase_offset[best_index]
    epoch = time_min + best_offset*best_period

    ax1.axvline(best_period, c='k', ls='--')
    # Mark the 2x/3x harmonics and sub-harmonics of the best period.
    for i in range(2, 4):
        ax1.axvline(best_period*i, c='b', ls='--')
        ax1.axvline(best_period/i, c='b', ls='--')
    ax1.set_xlim(periods[0], periods[-1])

    print('best period : ', best_period)
    print('Epoch : ', epoch)
    ax2.scatter(time/best_period %1 - best_offset, mag, c='k', s=10, alpha = 0.1)
    ax2.scatter(time/best_period %1 - best_offset-1, mag, c='k', s=10, alpha = 0.1)


    ax2.plot(time_template , mag_template, 'r')
    ax2.set_ylim(-1.5*depth_template,3*depth_template)

    ax2.invert_yaxis()
    ax2.set_xlabel('Phase')
    ax2.set_ylabel('Mag')
    ax2.set_title('Best period : {:.6} days\nEpoch : {:.5f}'.format(best_period, epoch))
    ax2.set_xlim(-0.1,0.1)

    plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=None, hspace=0.3)

    plt.savefig('MBLS.png')


    tmp = np.array([periods, SR, phase_offset]).T
    np.savetxt('MBLS.dat', tmp)
    plt.show()




