#!/home/sam/anaconda3/bin/python
import numpy as np
import matplotlib.pyplot as plt
import argparse
from bruce.binarystar import lc 
from scipy.stats import sem
from tqdm import tqdm
from multiprocessing import Pool


description = '''A program to search for transit events in ground-based photometry.'''

# Command-line interface. Time-like options are in the same units as the
# time column of the input file, except --bin, which is given in minutes.
parser = argparse.ArgumentParser('tls', description=description)

parser.add_argument("filename",
                    help='The filename of the ground-based photometry.')

parser.add_argument('-a',
                    '--bin',
                    help='The bin width from which to bin the lightcurve, in minutes [default=0, no binning].',
                    default=0.0, type=float)

parser.add_argument('-b',
                    '--period',
                    help='The orbital period in arbitrary time units consistent with the input file.',
                    default=4.511, type=float)

parser.add_argument('-c',
                    '--radius_1',
                    help='The radius of star 1 in units of the semi-major axis, a.',
                    default=0.13, type=float)

parser.add_argument('-d',
                    '--k',
                    help='The ratio of the radii of star 2 and star 1 (R2/R1).',
                    default=0.09619, type=float)

parser.add_argument('-e',
                    '--b',
                    help='The impact parameter of the orbit (incl = arccos(radius_1*b)).',
                    default=0., type=float)

# NOTE: light_3 is parsed but is not passed to lc() anywhere in this script.
parser.add_argument('-i',
                    '--light_3',
                    help='The third light.',
                    default=0.0, type=float)

parser.add_argument('-j',
                    '--ldc_1',
                    help='The first limb-darkening parameter [power-2]',
                    default=0.8, type=float)

parser.add_argument('-k',
                    '--ldc_2',
                    help='The second limb-darkening parameter [power-2]',
                    default=0.8, type=float)

parser.add_argument('-l',
                    '--threads',
                    help='Multiprocessing threads',
                    default=1, type=int)

def lc_bin(time, flux, bin_width):
    '''
    Bin a lightcurve into contiguous bins of a fixed width.

    Bins start at ``min(time)`` and are ``bin_width`` wide. Enough bins are
    used to cover the full time span, and the last bin is closed on the
    right, so the trailing partial bin and the final sample are kept.
    ``time`` and ``bin_width`` must have the same units.

    Parameters
    ----------
    time, flux : array_like
        Sample times and fluxes (same length).
    bin_width : float
        Width of each bin; must be > 0.

    Returns
    -------
    time_bin, flux_bin, err_bin : ndarray
        Bin centres, mean flux per bin and the standard error of that mean
        (``scipy.stats.sem``). Bins whose standard error is undefined (fewer
        than two samples, or NaN fluxes) are dropped.

    Raises
    ------
    ValueError
        If ``bin_width`` is not a positive number.
    '''
    if not bin_width > 0:
        raise ValueError('bin_width must be positive, got {!r}'.format(bin_width))

    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    t_min, t_max = time.min(), time.max()
    # ceil keeps the trailing partial bin; always use at least one bin so a
    # span shorter than bin_width (or a single timestamp) still bins.
    n_bins = max(int(np.ceil((t_max - t_min) / bin_width)), 1)
    edges = t_min + bin_width * np.arange(n_bins + 1)
    # digitize labels samples 1..n_bins; a sample lying exactly on the last
    # edge would get n_bins + 1, so clip it into the last (right-closed) bin.
    idx = np.clip(np.digitize(time, edges), 1, n_bins)

    centres, means, errors = [], [], []
    for i in range(1, n_bins + 1):
        in_bin = flux[idx == i]
        if in_bin.size < 2:
            # SEM needs at least two samples (ddof=1); skip to avoid NaN.
            continue
        err = sem(in_bin)
        if np.isnan(err):
            continue
        centres.append(0.5 * (edges[i - 1] + edges[i]))
        means.append(in_bin.mean())
        errors.append(err)

    return np.array(centres), np.array(means), np.array(errors)

def transit_duration(period, radius_1, k, b, incl):
    """
    Transit duration from the circular-orbit expression
    T = P/pi * arcsin(radius_1 * sqrt((1 + k)**2 - b**2) / sin(incl)).

    Parameters
    ----------
    period : float
        Orbital period; the result is returned in the same units.
    radius_1 : float
        Radius of star 1 in units of the semi-major axis.
    k : float
        Radius ratio R2/R1.
    b : float
        Impact parameter.
    incl : float
        Orbital inclination in degrees.

    Returns NaN (via numpy) when b > 1 + k, i.e. when there is no transit.
    """
    incl_rad = np.pi*incl/180
    chord = radius_1*np.sqrt((1 + k)**2 - b**2)
    half_angle = np.arcsin(chord/np.sin(incl_rad))
    return period*half_angle/np.pi



def fun(zp, t_zero):
    '''
    Scaled chi-square of a zero-point-shifted transit model against the data.

    The model is ``zp + lc(...)`` evaluated on the module-level
    ``t_zero_grid`` with a transit centred at ``t_zero``. It is compared with
    the module-level ``resampled_flux``. The geometry and limb darkening come
    from the module-level ``args`` and ``incl``. Those globals are created in
    the ``__main__`` section of this script.

    Returns ``sum((model - data)**2 / 0.1**2) / N``, where N is the number of
    grid points. The uncertainty is a fixed 0.1, so this is a scaled
    mean-square residual rather than a true reduced chi-square.
    '''
    shifted_model = zp + lc(t_zero_grid,
                            t_zero=t_zero,
                            period=args.period,
                            radius_1=args.radius_1,
                            k=args.k,
                            incl=incl,
                            ldc_1_1=args.ldc_1,
                            ldc_1_2=args.ldc_2)
    residual = shifted_model - resampled_flux
    n_points = t_zero_grid.shape[0]
    return np.sum(residual**2 / (0.1**2)) / n_points

def data_stream(a, b):
    '''
    Lazily enumerate the outer product of ``a`` and ``b``.

    Yields ``((i, j), (a[i], b[j]))`` for every pair, with ``b`` varying
    fastest. ``b`` is re-enumerated for each element of ``a``.
    '''
    yield from (((row, col), (a_val, b_val))
                for row, a_val in enumerate(a)
                for col, b_val in enumerate(b))

def proxy(args):
    '''
    Evaluate ``fun`` for one item produced by ``data_stream``.

    ``args`` is ``(index, params)``; ``fun(*params)`` is evaluated and
    ``(index, result)`` is returned. The result therefore stays tagged with
    its grid position when run through a process pool.
    '''
    task_index, fun_params = args[0], args[1]
    value = fun(*fun_params)
    return task_index, value


def _init_worker(shared):
    '''
    Pool initializer: install the globals that ``fun`` reads (``args``,
    ``incl``, ``t_zero_grid``, ``resampled_flux``) in a worker process.

    Workers started with "spawn" or "forkserver" re-import this module
    without running the ``__main__`` section, so they would otherwise lack
    these names. Under "fork" this just rebinds the same values.
    '''
    globals().update(shared)


if __name__ == "__main__":
    # First, parse args
    args = parser.parse_args()

    # Load the datafile: column 0 is time, column 1 is flux. Any further
    # columns are ignored. ndmin=2 keeps single-row files two-dimensional.
    data = np.loadtxt(args.filename, ndmin=2)
    time, flux = data[:, 0], data[:, 1]
    # np.interp (used below) requires increasing sample times.
    order = np.argsort(time)
    time, flux = time[order], flux[order]
    print('Successfully read {:} lines from {:}'.format(len(time), args.filename))
    if args.bin != 0.:
        time, flux, _ = lc_bin(time, flux, args.bin/60/24)
        print('\tBinned down to {:} lines'.format(len(time)))

    n_points = time.shape[0]
    if n_points < 2 or not time[-1] > time[0]:
        raise SystemExit('Need at least two data points at distinct times.')
    # n samples span n - 1 intervals.
    average_cadence = (time[-1] - time[0])/(n_points - 1)
    # Scatter from the first 10% of the data (at least two points, so it is defined).
    average_scatter = np.std(flux[:max(int(n_points*0.1), 2)])
    print('\tAverage cadence : {:.2f} minutes'.format(average_cadence*24*60))
    print('\tAverage scatter of first 10% : {:.0f}'.format(average_scatter*1e6))

    # Now get the model (incl in degrees, from the impact parameter)
    incl = 180*np.arccos(args.radius_1*args.b)/np.pi
    model_width = transit_duration(args.period, args.radius_1, args.k, args.b, incl)
    model_time = np.linspace(-0.8*model_width, 0.8*model_width, 1000)
    model_flux = lc(model_time, t_zero=0, period=args.period, radius_1=args.radius_1, k=args.k,
                    incl=incl, ldc_1_1=args.ldc_1, ldc_1_2=args.ldc_2)
    model_depth = np.max(model_flux) - np.min(model_flux)
    print('\tModel transit depth : {:.0f}'.format(model_depth*1e6))

    # Grid of zero-point offsets and trial transit centres. The centres extend
    # half a transit duration beyond each end of the data.
    zp_grid = np.linspace(-0.0005, 0.0005, 100)
    t_zero_grid = np.arange(time[0] - 0.5*model_width,
                            time[-1] + 0.5*model_width + average_cadence,
                            average_cadence)

    # Re-sample the data onto the grid. Grid points with no data are padded
    # with white noise around 1.0 using the measured scatter (assumes the flux
    # is normalised to ~1 -- TODO confirm for new datasets).
    resampled_flux = np.interp(t_zero_grid, time, flux, left=np.nan, right=np.nan)
    gaps = np.isnan(resampled_flux)
    resampled_flux[gaps] = np.random.normal(1.0, average_scatter, int(gaps.sum()))

    chis_reduced = np.zeros((zp_grid.shape[0], t_zero_grid.shape[0]))
    tasks = data_stream(zp_grid, t_zero_grid)

    if args.threads < 2:
        for (i, j), (zp, t_zero) in tqdm(tasks, total=chis_reduced.size):
            chis_reduced[i, j] = fun(zp, t_zero)
    else:
        shared = {'args': args, 'incl': incl,
                  't_zero_grid': t_zero_grid, 'resampled_flux': resampled_flux}
        with Pool(args.threads, initializer=_init_worker, initargs=(shared,)) as pool:
            # Each result carries its grid index, so completion order is irrelevant.
            results = pool.imap_unordered(proxy, tasks, chunksize=256)
            for index, value in tqdm(results, total=chis_reduced.size):
                chis_reduced[index] = value

    # The best fit is the one with the smallest chi-square.
    best = np.unravel_index(np.nanargmin(chis_reduced), chis_reduced.shape)
    print('Best transit center is at {:.5f} at offset {:.5f}'.format(t_zero_grid[best[1]], 1+zp_grid[best[0]]))

    plt.scatter(time, flux, c='k', s=10)
    plt.plot(t_zero_grid,
             zp_grid[best[0]] + lc(t_zero_grid, t_zero=t_zero_grid[best[1]], period=args.period,
                                   radius_1=args.radius_1, k=args.k, incl=incl,
                                   ldc_1_1=args.ldc_1, ldc_1_2=args.ldc_2),
             'r')
    plt.title(r'$\chi^2$ = ' + str(chis_reduced[best]))

    # Chi-square surface: rows are zero-point offsets, columns trial centres.
    plt.figure()
    plt.imshow(chis_reduced, aspect='auto')
    plt.show()