#!python

from astropy.table import Table

import numpy as np
from astroquery.mast import Catalogs
from astroquery.mast import Tesscut
from astropy.coordinates import SkyCoord
from astropy.wcs import WCS
from astropy.io import fits
import matplotlib.pyplot as plt
from tqdm import tqdm
import os,sys
import matplotlib.path
import math
from lightkurve.lightcurve import TessLightCurve
import imreg
import astropy.units as u
from astropy.coordinates import SkyCoord
from astroquery.gaia import Gaia

from scipy.signal import savgol_filter

import os, fnmatch
from scipy.interpolate import UnivariateSpline


# Time windows [start, stop] in BTJD (the time axis the main script labels
# 'BTJD') that contain ramps. For every window holding at least 5 points the
# main script fits a smoothing spline to the magnitudes inside the window and
# subtracts it.
ramp_masks = [
    [1383.65, 1384], [1395.5, 1396.6], [1382, 1382.5], [1385.85, 1385.97], [1392.26, 1392.44],
    [1413.06, 1413.26], [1413.26, 1413.44],
    [1421.1, 1421.89],     # S1
    [1421.9, 1422.86], [1422.5, 1422.75],
    [1437.8, 1438.75],
    [1450, 1450.22],
    [1435.75, 1436.96],
    [1451.5, 1452],
    [1463.59, 1464],
    [1478.1, 1478.47],
    [1517, 1518],
    [1530.25, 1531.85],
    [1534.97, 1536.5],
    [1569.3, 1570.75], [1582.69, 1584.76], [1596.5, 1599.2], [1610.5, 1612.7], [1624.82, 1626.84], [1653.83, 1655.07],
]

# Time windows [start, stop] in BTJD whose points the main script removes
# entirely (noisy stretches of data). An empty list disables this step.
noisy_mask = [
    [1347.08, 1350.14], [1382.53, 1383.69], [1394, 1396], [1410.96, 1411.11], [1412.26, 1412.32],
    [1422.58, 1423.56], [1435.91, 1436.89],
]

def find(pattern, path):
    """
    Recursively collect the files under ``path`` whose basename matches ``pattern``.

    ``pattern`` uses shell-style wildcards (``fnmatch``). The returned list holds
    full paths (``path`` joined with the sub-directory and file name), in the
    order in which ``os.walk`` visits them. A missing ``path`` yields ``[]``.
    """
    return [
        os.path.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(path)
        for filename in filenames
        if fnmatch.fnmatch(filename, pattern)
    ]



# Shared helper for the repeated cutout-plotting command.
def plot_cutout(image, ax):
    """
    Draw ``image`` on ``ax`` with a white grid on top.

    The grey-scale stretch runs from the 5th to the 96th percentile of the
    pixel values, using the inverted 'gray_r' colour map with the origin at
    the lower left.
    """
    low_cut = np.percentile(image, 5)
    high_cut = np.percentile(image, 96)
    ax.imshow(image, origin='lower', cmap='gray_r', vmax=high_cut, vmin=low_cut)
    ax.grid(axis='both', color='white', ls='solid')




def aperture_phot(image, aperture):
    """
    Return the total flux inside ``aperture`` for a single image.

    ``image`` and ``aperture`` must be 2D arrays of the same shape;
    ``aperture`` is boolean, with True marking the pixels whose light is summed.
    """
    selected_pixels = image[aperture]
    return np.sum(selected_pixels)

def make_lc(flux_data, aperture):
    """
    Build a photometric time series by applying one aperture to every frame.

    ``flux_data`` is iterated frame by frame (e.g. a stack of 2D images) and
    each frame is summed over the pixels where the boolean 2D ``aperture`` is
    True (via ``aperture_phot``). Returns a 1D numpy array with one summed
    flux per frame.
    """
    per_frame_flux = [aperture_phot(frame, aperture) for frame in flux_data]
    return np.array(per_frame_flux)


def phaser(time, t_zero, period):
    """Return the phase of ``time`` folded on ``period`` with phase 0 at ``t_zero``.

    The result is the fractional part of ``(time - t_zero) / period``, i.e. a
    value in [0, 1) up to floating-point rounding. Works on scalars and arrays.
    """
    cycles = (time - t_zero) / period
    return cycles - np.floor(cycles)


# A function to estimate the windowed scatter in a lightcurve
def estimate_scatter_with_mask(mask, flux, window_length=501, polyorder=5):
    """
    Estimate the scatter of an aperture light curve, in parts per million.

    The pixels selected by ``mask`` are summed in every frame of ``flux``. The
    series is smoothed with a Savitzky-Golay filter. The scatter is the square
    root of the median squared fractional residual, times 1e6.

    Parameters
    ----------
    mask : 2D bool array
        Aperture; True marks the pixels to include.
    flux : array
        Stack of frames indexed so that ``flux[:, mask]`` picks the aperture
        pixels of every frame (frame index first).
    window_length : int, optional
        Savitzky-Golay window, default 501. It is reduced to the number of
        usable frames and then made odd, so short series do not make
        ``savgol_filter`` raise.
    polyorder : int, optional
        Savitzky-Golay polynomial order, default 5.

    Returns
    -------
    float
        Scatter in ppm. Returns ``np.inf`` when too few finite frames remain
        for the filter (window <= polyorder). ``np.argmin`` over several
        candidates will then never prefer this one.
    """
    f = np.sum(flux[:, mask], axis=-1)
    # Drop non-finite cadences. A NaN would spread through the filter and
    # make the median NaN, and np.argmin (used by the caller to pick the best
    # aperture) returns the index of the first NaN.
    f = f[np.isfinite(f)]

    window = min(window_length, f.size)
    if window % 2 == 0:
        window -= 1  # keep the Savitzky-Golay window odd
    if window <= polyorder:
        return np.inf

    smooth = savgol_filter(f, window, polyorder=polyorder)
    return 1e6 * np.sqrt(np.median((f / smooth - 1) ** 2))

def clip_aperture_to_radius(aperture, radius=3.0):
    """
    Switch off aperture pixels that lie far from the centre of the cutout.

    Parameters
    ----------
    aperture : 2D bool array
        Candidate aperture; True marks an included pixel.
    radius : float, optional
        Pixels whose distance from the point ``(shape[0] / 2, shape[1] / 2)``
        exceeds this value (in pixels) are removed. Default 3.

    Returns
    -------
    2D bool array
        A new array of the same shape; ``aperture`` itself is not modified.
    """
    aperture = np.asarray(aperture, dtype=bool)
    rows, cols = np.indices(aperture.shape)
    # Distance of every pixel from the cutout centre, using both the row and
    # the column offset (works for non-square cutouts as well).
    dist = np.hypot(rows - aperture.shape[0] / 2, cols - aperture.shape[1] / 2)
    return aperture & (dist <= radius)


if __name__ == '__main__':
    ######################################################
    # Step 1 - parse the target from the command line
    ######################################################
    # Two invocation forms are accepted; the last argument of each form is
    # not read anywhere in this script:
    #   RA DEC SIZE X  -> coordinates given directly in degrees
    #   TIC SIZE X     -> coordinates looked up in the TIC catalogue
    if len(sys.argv) == 5:
        ra, dec = float(sys.argv[1]), float(sys.argv[2])
        size = int(sys.argv[3])
    elif len(sys.argv) == 4:
        tic_catalog_data = Catalogs.query_object('TIC' + sys.argv[1], radius=.02, catalog="TIC")
        ra = tic_catalog_data['ra'][0] * u.deg
        dec = tic_catalog_data['dec'][0] * u.deg
        size = int(sys.argv[2])
    else:
        # Without this branch ra/dec/size would be undefined below.
        sys.exit('usage: {0} RA DEC SIZE X  |  {0} TIC SIZE X'.format(sys.argv[0]))

    #########################################
    # Step 2 - Query the object with TESS cut
    #########################################
    coord = SkyCoord(ra, dec, unit="deg")
    try:
        hdulist = Tesscut.get_cutouts(coord, size)
        numsecs = len(hdulist)
        hdu1 = hdulist[0]  # IndexError when no cutout was returned
    except Exception as err:
        raise ValueError('I failed to get the TESS cut :(') from err

    ###############################################
    # Step 3 - set up the light-curve and image figures
    ###############################################
    fig2 = plt.figure(figsize=(15, 5))
    lcax = plt.gca()

    fig = plt.figure(figsize=(5, 5 * numsecs))
    gs = fig.add_gridspec(numsecs, 1)

    master_time, master_mag = [], []

    for i, hdu1 in enumerate(hdulist):
        flux_cube = hdu1[1].data['FLUX']
        image = np.median(flux_cube, axis=0)
        # Pixel indices from brightest to faintest. NaN pixels are mapped to
        # -inf so they sort last instead of being treated as the brightest.
        order = np.argsort(np.nan_to_num(image, nan=-np.inf).flatten())[::-1]
        wcs = WCS(hdu1[2].header)

        ax = fig.add_subplot(gs[i, 0], projection=wcs)
        plot_cutout(image, ax)

        ##############################################################
        # Step 4 - choose the aperture that minimises the scatter
        ##############################################################
        # Grow the aperture one pixel at a time in order of decreasing median
        # brightness (30 to 99 pixels) and keep the lowest-scatter candidate.
        masks, scatters = [], []
        for j in range(30, 100):
            msk = np.zeros_like(image, dtype=bool)
            msk[np.unravel_index(order[:j], image.shape)] = True
            masks.append(msk)
            scatters.append(estimate_scatter_with_mask(msk, flux_cube))

        # Keep only pixels within 3 px of the cutout centre.
        aperture = clip_aperture_to_radius(masks[np.argmin(scatters)], radius=3)

        # Plot the selected aperture
        ax.imshow(image.T, cmap="gray_r")
        ax.imshow(~aperture.T, cmap="Reds", alpha=0.3)
        ax.plot(aperture.shape[0] / 2, aperture.shape[1] / 2, 'r+')
        if np.sum(aperture) < 2:
            print('BAD MASK')
            plt.show()
            sys.exit()

        ax.set_xlabel('R.A')
        ax.set_ylabel('Dec')

        ############################################################
        # Step 5 - background flux from the faintest 10 % of pixels
        ############################################################
        # The background pixels are chosen on the first frame only.
        first_frame = flux_cube[0]
        bkg_aperture = first_frame < np.percentile(first_frame, 10)
        bkg_flux = make_lc(flux_cube, bkg_aperture)

        ############################################################
        # Step 6 - background-subtracted aperture photometry
        ############################################################
        flux1 = make_lc(flux_cube, aperture)
        # Scale the summed background to the number of aperture pixels.
        bkg_sub_flux = flux1 - (bkg_flux * np.sum(aperture) / np.sum(bkg_aperture))

        # Drop cadences with a non-zero QUALITY flag or a positive POS_CORR1.
        # NOTE(review): only POS_CORR1 is tested; POS_CORR2 may also have been
        # intended here - confirm before changing.
        bad_cadence = (hdu1[1].data['QUALITY'] > 0) | (hdu1[1].data['POS_CORR1'] > 0)
        time1 = hdu1[1].data['TIME'][~bad_cadence]
        bkg_sub_flux = bkg_sub_flux[~bad_cadence]

        ########################
        # Step 7 - flatten the LC and convert to magnitudes
        ########################
        s = TessLightCurve(time1, bkg_sub_flux, 3000 * 1e-6 * np.ones(time1.shape[0]))
        s = s.flatten()
        time, mag = s.time, -2.5 * np.log10(s.flux)
        # log10(0) gives -inf; remove those points.
        not_inf = ~np.isinf(mag)
        time = time[not_inf]
        mag = mag[not_inf]

        ########################
        # Step 8 - ramp removal and noisy-window cuts
        ########################
        for t_start, t_stop in ramp_masks:
            window = (time > t_start) & (time < t_stop)
            # Fit only finite magnitudes so a NaN cannot spoil the spline.
            fit = window & np.isfinite(mag)
            if np.sum(fit) < 5:
                continue
            spline = UnivariateSpline(time[fit], mag[fit], s=10, k=3)
            mag[window] = mag[window] - spline(time[window])

        for t_start, t_stop in noisy_mask:
            keep = ~((time > t_start) & (time < t_stop))
            time = time[keep]
            mag = mag[keep]

        good = ~np.isnan(mag)
        master_time.extend(time[good])
        master_mag.extend(mag[good])

    # Replace isolated points that jump by more than drop_out_thresh mag
    # relative to BOTH neighbours with a random draw.
    # NOTE(review): this substitutes synthetic values for data and is not
    # reproducible between runs; masking these points may be preferable.
    drop_out_thresh = 0.04
    dropcount = 0
    for i in range(1, len(master_time) - 1):
        d1 = abs(master_mag[i] - master_mag[i - 1])
        d2 = abs(master_mag[i] - master_mag[i + 1])
        if d1 > drop_out_thresh and d2 > drop_out_thresh:
            master_mag[i] = np.random.normal(0, np.nanstd(master_mag))
            dropcount += 1

    lcax.scatter(master_time, master_mag, c='k', s=10)

    # Columns: time, magnitude, constant 1e-3 uncertainty.
    output = np.array([master_time, master_mag, np.ones(len(master_time)) * 1e-3]).T

    lcax.invert_yaxis()
    lcax.set_ylabel('T mag')
    lcax.set_xlabel('BTJD')
    fig2.tight_layout()
    fig.subplots_adjust(left=0.2)

    if len(sys.argv) == 4:
        np.savetxt('{:}.dat'.format(sys.argv[1]), output)
        fig.savefig('{:}_aperture.png'.format(sys.argv[1]))
        fig2.savefig('{:}_lightcurve.png'.format(sys.argv[1]))
    else:
        fig.savefig('aperture.png')
        fig2.savefig('lightcurve.png')
        np.savetxt('LC.dat', output)

    plt.close()
    sys.exit()