#!python
import numpy as np 
import argparse 
from photutils import aperture_photometry, CircularAperture, CircularAnnulus
from astropy.visualization import simple_norm
from astropy.io import fits
import matplotlib.pyplot as plt


import numpy as np
from astropy.io import fits
from scipy.signal import fftconvolve
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from numba import autojit

def Gauss(x, a, x0, sigma):
    """Evaluate an un-normalised Gaussian of height ``a`` centred on ``x0``.

    ``sigma`` is the standard-deviation width. Works element-wise on NumPy
    arrays as well as on plain scalars (it is used as a ``curve_fit`` model).
    """
    offset = x - x0
    exponent = -offset**2 / (2 * sigma**2)
    return a * np.exp(exponent)

@autojit(nopython=True)
def count_flux(image, x,y,r1,r2,r3):
    """Sum pixel values inside three concentric circles centred on (x, y).

    Pixel ``image[j, i]`` is treated as having centre coordinates (i, j):
    ``i`` runs along columns (x) and ``j`` along rows (y). A pixel contributes
    to a sum when its centre lies strictly inside the corresponding radius.

    Parameters
    ----------
    image : 2-D array indexed as ``image[row, column]``.
    x, y : centre of the circles in pixel coordinates (column, row).
    r1, r2, r3 : radii of the three circles.

    Returns
    -------
    (flux1, flux2, flux3) : pixel sums inside radii r1, r2 and r3.

    NOTE(review): ``numba.autojit`` is not available in recent numba releases;
    ``numba.njit`` is the modern equivalent -- confirm against the numba
    version in use.
    """
    flux1 = 0.
    flux2 = 0.
    flux3 = 0.
    # Compare squared distances to avoid a sqrt per pixel.
    r1_sq = r1 * r1
    r2_sq = r2 * r2
    r3_sq = r3 * r3

    for i in range(image.shape[1]):
        dx = i - x
        for j in range(image.shape[0]):
            dy = j - y
            # Distance is measured from the requested centre (x, y),
            # not from the array origin.
            dist_sq = dx * dx + dy * dy
            pixel = image[j, i]
            if dist_sq < r1_sq:
                flux1 = flux1 + pixel
            if dist_sq < r2_sq:
                flux2 = flux2 + pixel
            if dist_sq < r3_sq:
                flux3 = flux3 + pixel

    return flux1, flux2, flux3





def _ccf_peak_offset(profile, lags):
    """Fit a Gaussian to the central part of a 1-D CCF profile; return (offset, error).

    The profile is rescaled to span [0, 1] and trimmed to the window between
    35 % and 70 % of its length. It is then fitted with ``Gauss`` (initial
    centre 0, width 1). The returned offset is minus the fitted centre, in lag
    units, together with its 1-sigma error from the covariance matrix. If the
    fit cannot be performed, the sentinel (0, 99) is returned.
    """
    profile = profile - np.min(profile)
    profile = profile / profile.max()

    # Only fit a window around zero lag so that the edges of the correlation do
    # not dominate. NOTE(review): the window (0.35..0.7 of the length) is not
    # symmetric about zero lag -- confirm this is intended.
    n = lags.shape[0]
    n_low, n_high = int(np.floor(0.35 * n)), int(np.floor(0.7 * n))
    profile = profile[n_low:n_high]
    lags = lags[n_low:n_high]

    try:
        popt, pcov = curve_fit(Gauss, lags, profile, p0=[profile.max(), 0, 1])
    except (RuntimeError, ValueError, TypeError):
        # RuntimeError: no convergence; ValueError: NaN/inf data (e.g. a flat
        # profile normalised by zero) or an empty window; TypeError: fewer data
        # points than fit parameters.
        return 0, 99
    perr = np.sqrt(np.diag(pcov))
    return -popt[1], perr[1]


def find_offset(im1, im2, xlim, ylim, method='gaussian'):
    """Estimate the translation of ``im2`` relative to ``im1`` by cross-correlation.

    Both images are trimmed, mean-subtracted and cross-correlated with
    ``fftconvolve`` (``mode='same'``).

    Parameters
    ----------
    im1, im2 : 2-D arrays of the same shape (reference and comparison image).
        They are not modified.
    xlim : [start, stop] slice applied to axis 0 (rows) of both images.
        Despite the name, this indexes the first array axis.
    ylim : [start, stop] slice applied to axis 1 (columns) of both images.
    method : 'best_pixel', 'gaussian_fit', or 'gaussian' (an alias for
        'gaussian_fit' and the default).

    Returns
    -------
    (shift_col, err_col, shift_row, err_row)
        Shift of ``im2`` relative to ``im1`` along axis 1 (columns) and axis 0
        (rows). A shift is positive when features in ``im2`` lie at larger
        indices than in ``im1``.
        - 'best_pixel': integer-pixel shifts (as floats), each with a nominal
          error of 0.5 px.
        - 'gaussian_fit': a Gaussian fitted to the axis-averaged correlation.
          An axis whose fit fails reports (0, 99).

    Raises
    ------
    ValueError
        If a lower limit exceeds the upper limit, or ``method`` is unknown.
    """
    #####################
    # Preliminary checks
    #####################
    if xlim[0] > xlim[1]:
        raise ValueError('The first x limit cannot exceed the second.')
    if ylim[0] > ylim[1]:
        raise ValueError('The first y limit cannot exceed the second.')
    if method not in ('best_pixel', 'gaussian_fit', 'gaussian'):
        msg='''
Method choice not understood.

Available choices are:

1) best_pixel

2) gaussian_fit
'''
        raise ValueError(msg)

    ##########################################################
    # Trim, convert to float64 and remove the mean (without
    # touching the caller's arrays); the mean must go or the
    # correlation is dominated by the overlap area.
    ##########################################################
    im1 = np.asarray(im1, dtype=np.float64)[xlim[0]:xlim[1], ylim[0]:ylim[1]]
    im2 = np.asarray(im2, dtype=np.float64)[xlim[0]:xlim[1], ylim[0]:ylim[1]]
    im1 = im1 - im1.mean()
    im2 = im2 - im2.mean()

    ##########################################################################
    # Cross-correlation = convolution with one image flipped on both axes.
    # In 'same' mode zero lag sits at index n // 2 along each axis, for both
    # odd and even n (n / 2 would introduce a half-pixel bias for odd n).
    ##########################################################################
    corr_img = fftconvolve(im1, im2[::-1, ::-1], mode='same')
    centre = np.array(corr_img.shape) // 2

    if method == 'best_pixel':
        peak = np.array(np.unravel_index(np.argmax(corr_img), corr_img.shape))
        shift_row, shift_col = (centre - peak).astype(np.float64)
        return shift_col, 0.5, shift_row, 0.5

    # 'gaussian_fit' / 'gaussian': collapse the correlation along each axis and
    # fit the resulting 1-D profiles.
    col_profile = corr_img.mean(axis=0)   # function of column lag
    row_profile = corr_img.mean(axis=1)   # function of row lag
    col_lags = np.arange(corr_img.shape[1]) - centre[1]
    row_lags = np.arange(corr_img.shape[0]) - centre[0]

    shift_col, err_col = _ccf_peak_offset(col_profile, col_lags)
    shift_row, err_row = _ccf_peak_offset(row_profile, row_lags)
    return shift_col, err_col, shift_row, err_row

           
        
    

        






# Welcome message printed at start-up
welcome_message = '''---------------------------------------------------
-                   quickphot V.1                 -
-             samuel.gill@wariwck.ac.uk           -
---------------------------------------------------'''

description = '''A program to perform quick differential aperture photometry of a target, comparison and check star in a FITS data cube.'''

parser = argparse.ArgumentParser('quickphot', description=description)

parser.add_argument('-a',
                    '--reference_image_idx',
                    help='The reference image idx.',
                    type=int, default=0)

parser.add_argument('-aa',
                    '--datacube',
                    help='The datacube. ',
                    type=str, default='data.fits')

# Star positions take exactly two values (x y); downstream code formats and
# indexes them as a pair.
parser.add_argument('-b', '--target',
                    help='The target star in pixel position e.g. --target x y',
                    nargs=2, metavar=('X', 'Y'),
                    default=[0,0], type = float)

parser.add_argument('-c', '--comparison',
                    help='The comparison star in pixel position e.g. --comparison x y',
                    nargs=2, metavar=('X', 'Y'),
                    default=[0,0], type = float)

parser.add_argument('-d', '--check',
                    help='The check star in pixel position e.g. --check x y',
                    nargs=2, metavar=('X', 'Y'),
                    default=[0,0], type = float)

parser.add_argument('-e', '--r1',
                    help='The radius of the photometric aperture.',
                    type = float, default=5.)

parser.add_argument('-f', '--skyin',
                    help='The inner radius of the sky annulus.',
                    type = float, default=8.)

parser.add_argument('-g', '--skyout',
                    help='The outer radius of the sky annulus.',
                    type = float, default=10.)

# Plot limits are optional; [None, None] means "use the full field".
parser.add_argument('-i', '--trialxlim',
                    help='The xlim for the field e.g. --trialxlim 0 10.',
                    nargs=2, metavar=('MIN', 'MAX'), type=float,
                    default=[None,None])

parser.add_argument('-j', '--trialylim',
                    help='The ylim for the field e.g. --trialylim 0 10.',
                    nargs=2, metavar=('MIN', 'MAX'), type=float,
                    default=[None,None])

parser.add_argument('--image_align', action="store_true", default=False,
                    help='Align each frame to the reference image before photometry.')

parser.add_argument('--trial', action="store_true", default=False,
                    help='Only plot the field with the apertures, then exit.')


def p_(x):
    """Return a flat baseline of ones, one entry per element along axis 0 of *x*.

    Mimics the call signature of an ``np.poly1d`` so it can stand in for a
    fitted trend.
    """
    length = x.shape[0]
    return np.full(length, 1.0)

if __name__=="__main__":
    # Print the welcome banner
    print(welcome_message)

    # Parse command-line arguments
    args = parser.parse_args()

    # Echo the data cube, reference frame and star coordinates
    print('Data cube: ' ,args.datacube)
    print('\tReference idx : ', args.reference_image_idx)
    print('\tTarget:     x = {:.3f}, y = {:.3f}'.format(*args.target))
    print('\tComparison: x = {:.3f}, y = {:.3f}'.format(*args.comparison))
    print('\tCheck:      x = {:.3f}, y = {:.3f}'.format(*args.check))
    print('---------------------------------------------------')

    # Reference (x, y) pixel positions, ordered target, comparison, check;
    # flux1/flux2/flux3 below follow the same order.
    star_labels = ('Target', 'Comparison', 'Check')
    star_positions = np.array([args.target, args.comparison, args.check], dtype=float)

    def build_apertures(shift_x=0., shift_y=0.):
        """Return (aperture, annulus) for all stars shifted by (shift_x, shift_y) pixels."""
        shifted = star_positions + np.array([shift_x, shift_y])
        return (CircularAperture(shifted, r=args.r1),
                CircularAnnulus(shifted, r_in=args.skyin, r_out=args.skyout))

    aperture, annulus_aperture = build_apertures()

    # Load the data cube. memmap=False reads the pixels into memory so the
    # array stays valid after the with-statement closes the file.
    with fits.open(args.datacube, memmap=False) as hdul:
        data = hdul[0].data
        name = hdul[0].header['OBJECT']
    reference_image = data[args.reference_image_idx, :, :]
    norm = simple_norm(reference_image, 'sqrt', percent=99)

    # Plot the reference frame with the apertures overlaid
    plt.imshow(reference_image, norm=norm, origin='lower', aspect='auto')
    aperture.plot(color='white', lw=2)
    annulus_aperture.plot(color='red', lw=2)
    plt.xlabel('X pix')
    plt.ylabel('Y pix')
    for label, (star_x, star_y) in zip(star_labels, star_positions):
        plt.text(star_x + 10, star_y + 10, s=label)

    # Optionally zoom in on a sub-region of the field
    if None not in args.trialxlim:
        plt.xlim([float(lim) for lim in args.trialxlim])
    if None not in args.trialylim:
        plt.ylim([float(lim) for lim in args.trialylim])
    plt.gca().invert_xaxis()
    plt.savefig('{:}_field.png'.format(name))

    # Trial mode: show the field and stop before doing any photometry
    if args.trial:
        plt.show()
        raise SystemExit(0)
    plt.close()

    # Aperture photometry on every frame of the cube
    flux1, flux2, flux3, X, Y = [], [], [], [], []
    for i in range(data.shape[0]):
        frame = data[i, :, :]
        dx, dy = 0., 0.
        if args.image_align:
            # find_offset returns (shift along columns, err, shift along rows,
            # err), positive when the frame's content lies at larger pixel
            # indices than in the reference, so BOTH components are added to
            # the (x, y) positions.
            dx, _, dy, _ = find_offset(reference_image, frame,
                                       xlim=[0, reference_image.shape[0]],
                                       ylim=[0, reference_image.shape[1]],
                                       method='gaussian_fit')
            # Treat implausibly large (or non-finite) offsets as "no shift"
            if not np.isfinite(dx) or abs(dx) > 50:
                dx = 0.
            if not np.isfinite(dy) or abs(dy) > 50:
                dy = 0.
            aperture, annulus_aperture = build_apertures(dx, dy)
        X.append(dx)
        Y.append(dy)

        # Background-subtracted flux: aperture sum minus the mean annulus
        # level scaled to the aperture area.
        phot_table = aperture_photometry(frame, [aperture, annulus_aperture])
        bkg_mean = phot_table['aperture_sum_1'] / annulus_aperture.area
        final_sum = phot_table['aperture_sum_0'] - bkg_mean * aperture.area
        flux1.append(final_sum[0])
        flux2.append(final_sum[1])
        flux3.append(final_sum[2])

    flux1, flux2, flux3 = np.array(flux1), np.array(flux2), np.array(flux3)
    frames = np.arange(len(flux1))
    target_over_comp = flux1 / flux2

    fig, axs = plt.subplots(3, 2, figsize=(10, 15))
    axs[0, 0].scatter(frames, target_over_comp, c='k', s=10)
    # Quadratic trend of target/comparison; fall back to a flat baseline if
    # the fit fails.
    try:
        p = np.poly1d(np.polyfit(frames, target_over_comp, 2))
    except (np.linalg.LinAlgError, ValueError, TypeError):
        p = p_
    trend = p(frames)
    axs[0, 0].plot(frames, trend, 'r')

    detrended = target_over_comp / trend
    axs[0, 1].scatter(frames, detrended, c='k', s=10)
    axs[0, 1].set_title('RMS : {:.0f} ppm'.format(1e6 * np.std(detrended)))
    axs[1, 0].scatter(frames, flux1 / flux3, c='k', s=10)
    axs[2, 0].scatter(frames, flux2 / flux3, c='k', s=10)

    axs[1, 1].scatter(frames, X, c='k', s=10)
    axs[2, 1].scatter(frames, Y, c='k', s=10)

    axs[0, 0].set_ylabel('target / comparison')
    axs[1, 0].set_ylabel('target / check ')
    axs[2, 0].set_ylabel('comparison / check')
    axs[1, 1].set_ylabel('X [pix]')
    axs[2, 1].set_ylabel('Y [pix]')
    axs[2, 0].set_xlabel('Frame')
    axs[2, 1].set_xlabel('Frame')
    axs[0, 0].set_title(name)

    plt.savefig('{:}_photometry.png'.format(name))
    plt.show()






'''
positions = np.transpose((sources['xcentroid'], sources['ycentroid']))  
apertures = CircularAperture(positions, r=4.)  
phot_table = aperture_photometry(image, apertures)  
for col in phot_table.colnames:  
'''