#!/Users/benjaminpope/anaconda2/bin/python
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table
import scipy.optimize as optimize
from astropy.io import fits
from time import time as clock
from os.path import join, exists, abspath, basename

from halophot.halo_tools import *

from argparse import ArgumentParser

import matplotlib as mpl

# The bundled seaborn style sheets were renamed with a "seaborn-v0_8-" prefix
# in matplotlib 3.6 and the un-prefixed names were removed afterwards. Prefer
# the new name when this matplotlib provides it and fall back to the legacy
# name otherwise (if neither exists, style.use raises exactly as before).
_style = 'seaborn-v0_8-colorblind'
if _style not in mpl.style.available:
	_style = 'seaborn-colorblind'
mpl.style.use(_style)

# Pin the matplotlib settings so that figures always look the same
# (the values in the trailing comments are the ipython notebook settings).

mpl.rcParams['figure.figsize']=(8.0,6.0)    #(6.0,4.0)
mpl.rcParams['font.size']=18               #10 
mpl.rcParams['savefig.dpi']= 200             #72 
mpl.rcParams['axes.labelsize'] = 16
mpl.rcParams['xtick.labelsize'] = 12
mpl.rcParams['ytick.labelsize'] = 12


'''-----------------------------------------------------------------
halo

This executable Python script allows you to detrend any single object using halo
photometry.

An example call is 

halo ktwo200007768-c04_lpd-targ.fits --data-dir /home/ben/Data/kepler/halo/ --name atlas -c 4 --do-plot
-----------------------------------------------------------------'''

# Default split times keyed by campaign number; campaigns that are not listed
# here have no default splits.
CAMPAIGN_SPLITS = {4: [550, 2200]}


def _float_list(s):
	"""Parse a string such as '[550,2200]' or '550,2200' into a 1-D float array."""
	return np.fromstring(s.strip('[]'), sep=',')


def build_parser():
	"""Build the command-line parser for the halo script.

	Returns:
		An ArgumentParser whose parsed namespace carries every option read by
		main(). All original options are kept; ``--no-analytic`` is added
		because ``--analytic`` defaults to True and could not be turned off.
	"""
	ap = ArgumentParser(description='halophot: K2 halo photometry with total variation.')
	ap.add_argument('fname', type=str, help='Input target pixel file name.')
	ap.add_argument('--data-dir', default='', type=str)
	ap.add_argument('--name', default='test',type=str,help='Target name')
	ap.add_argument('-c', '--campaign', metavar='C',default=4, type=int,
		help='Campaign number')
	ap.add_argument('-o', '--order', metavar='O', type=int,default=1,
		help='TV Order: 1 for gradient, 2 for concavity')
	ap.add_argument('-sub',  type=int,default=1, help='Subsampling parameter')
	ap.add_argument('-maxiter',  type=int,default=151, help='Maximum # iterations')
	ap.add_argument('--splits', default=None, type=_float_list,
		help='List of time values for kernel splits')
	ap.add_argument('--rr', default=None, type=_float_list,
		help='rmin, rmax (pix)')
	ap.add_argument('--quiet', action='store_true', default=False,
		help='suppress messages')
	ap.add_argument('--save-dir', default='.',
		help='The directory to save the output file in')
	ap.add_argument('--do-plot', action='store_true', default=False,
		help='produce plots')
	ap.add_argument('--do-split', action='store_true', default=False,
		help='process the light curve in three segments divided at the two split times')
	ap.add_argument('--random-init', action='store_true', default=False,
		help='initialize search with random seed')
	ap.add_argument('--minflux', type=float,default=100., help='Minimum flux to include')
	ap.add_argument('--thresh', type=float,default=0.8, help='What fraction of saturation to throw away')
	ap.add_argument('--consensus', action='store_true', default=False,
		help='use with subsampling to run fast and avoid overfitting')
	# --analytic is kept for backward compatibility; it is already the default.
	ap.add_argument('--analytic', dest='analytic', action='store_true', default=True,
		help='use analytic derivatives; orders of magnitude faster')
	ap.add_argument('--no-analytic', dest='analytic', action='store_false', default=True,
		help='do not use analytic derivatives')
	ap.add_argument('--sigclip', action='store_true', default=False,
		help='sigma-clip the final light curve')
	return ap


def resolve_splits(user_splits, campaign):
	"""Return the split times to use.

	Args:
		user_splits: value of --splits, or None when it was not given.
		campaign: campaign number used to look up a default.

	Returns:
		``user_splits`` unchanged when given, otherwise the default for the
		campaign from CAMPAIGN_SPLITS, or None when there is no default.
	"""
	if user_splits is not None:
		return user_splits
	return CAMPAIGN_SPLITS.get(campaign)


def campaign_mask(campaign, t):
	"""Return the boolean mask of cadences to keep for a campaign.

	Args:
		campaign: campaign number.
		t: array-like of time stamps (the 'time' column of the time series).

	Returns:
		A boolean array that is True for cadences to keep, or None when the
		campaign has no hard-coded cut. For campaign 13 three time windows
		are removed; for campaigns 10 and 7 only times after a fixed start
		are kept.
	"""
	if campaign == 13:
		# Each and/or pair removes one window; the chain is evaluated in
		# order and relies on the windows being listed in increasing time.
		keep = np.logical_or(t<2988.2553329814764, t>2988.494)
		keep = np.logical_and(keep, t<3064.016165412577)
		keep = np.logical_or(keep, t>3064.75)
		keep = np.logical_and(keep, t<3065.9776255118923)
		return np.logical_or(keep, t>3066.2225)
	if campaign == 10:
		return t>2760
	if campaign == 7:
		return t>2470
	return None


def main(argv=None):
	"""Run halo photometry on one target pixel file and save the results.

	Loads the file, applies any campaign-specific time cut and the optional
	annulus, runs do_lc (once, or once plus three segments with --do-split),
	writes the weight map and time series to a FITS file and, with --do-plot,
	writes three PNG figures.

	Args:
		argv: list of command-line arguments; None means sys.argv[1:].

	Raises:
		SystemExit: when the save directory does not exist, or when
			--do-split is requested without two split times.
	"""
	import copy  # local import: only needed to copy colormaps for plotting

	args = build_parser().parse_args(argv)
	splits = resolve_splits(args.splits, args.campaign)

	# Fail before the expensive optimisation rather than at write time.
	if not exists(args.save_dir):
		raise SystemExit("Error: the save directory {:s} doesn't exist".format(args.save_dir))

	if args.do_split and (splits is None or len(splits) < 2):
		raise SystemExit('Error: --do-split needs two split times; '
			'none are defined for campaign %d, so pass --splits t1,t2' % args.campaign)

	### first load your data
	# join() copes with --data-dir given with or without a trailing slash
	fname = join(args.data_dir, args.fname)
	tpf, ts = read_tpf(fname)

	keep = campaign_mask(args.campaign, ts['time'])
	if keep is not None:
		tpf, ts = tpf[keep,:,:], ts[keep]

	print('Data loaded!')

	start = clock()

	# get annulus if necessary
	if args.rr is not None:
		rmin, rmax = args.rr
		print('Getting annulus from',rmin,'to',rmax)
		tpf = get_annulus(tpf,rmin,rmax)
		print('Using',np.sum(np.isfinite(tpf[0,:,:])),'pixels')

	# keyword arguments shared by every do_lc call
	lc_kwargs = dict(maxiter=args.maxiter, random_init=args.random_init,
		thresh=args.thresh, minflux=args.minflux, consensus=args.consensus,
		analytic=args.analytic, sigclip=args.sigclip)

	if args.do_split:
		print('First doing one run to establish weights')

		tpf, newts, weights, weightmap, pixelvector = do_lc(tpf, ts, (None,None),
			args.sub, args.order, **lc_kwargs)

		print('Splitting at',splits)

		# Each segment is started from the weights of the full run. After
		# the loop, weightmap and pixelvector are those of the last segment.
		bounds = [(None,splits[0]), (splits[0],splits[1]), (splits[1],None)]
		segments = []
		for bound in bounds:
			_, seg_ts, _, weightmap, pixelvector = do_lc(tpf, ts, bound,
				args.sub, args.order, w_init=weights, **lc_kwargs)
			segments.append(seg_ts)

		## now stitch these
		newts = stitch(segments)
	else:
		print('Not splitting')
		tpf, newts, weights, weightmap, pixelvector = do_lc(tpf, ts, (None,None),
			args.sub, args.order, **lc_kwargs)

	print_time(clock()-start)

	opt_lc = newts['corr_flux'][:]

	if args.sigclip:
		# NOTE(review): with --do-split, `weights` comes from the initial
		# full run, `pixelvector` from the last segment and `good` from the
		# stitched curve; that combination looks inconsistent - confirm
		# before relying on --sigclip together with --do-split.
		good = sigma_clip(opt_lc,max_sigma=3.5)

		print('Clipping %d bad points' % np.sum(~good))
		pixelvector, newts = pixelvector[:,good], newts[good]
		opt_lc = np.dot(weights,pixelvector)

	rel_lc = opt_lc/np.nanmedian(opt_lc)
	npoints = float(np.size(opt_lc))
	tv1 = diff_1(rel_lc)/npoints
	tv2 = diff_2(rel_lc)/npoints

	print('Total variation per point (first order): %f ' % tv1)

	print('Total variation per point (second order): %f' % tv2)

	### save your new light curve!

	norm = np.size(weightmap)

	# The primary HDU holds the plain (unmasked) weight map; the masked
	# version is only built afterwards, for plotting.
	hdu = fits.PrimaryHDU(weightmap.T)
	cols = [fits.Column(name=key,format="D",array=newts[key]) for key in newts.keys()]
	tab = fits.BinTableHDU.from_columns(cols)

	lc_file = '%s/%shalo_lc_o%s.fits' % (args.save_dir,args.name,args.order)
	fits.HDUList([hdu, tab]).writeto(lc_file,overwrite=True)
	print('Saved halo-corrected light curve to %s' % lc_file)

	weightmap = np.ma.array(weightmap,mask=np.isnan(weightmap))

	if args.do_plot:
		# figure 1: relative light curve, positive fluxes only
		m = (opt_lc>0.)
		plt.figure(1)
		plt.clf()
		plt.plot(newts['time'][m],opt_lc[m]/np.nanmedian(opt_lc[m]),'-')
		plt.xlabel('Time')
		plt.ylabel('Relative Flux')
		plt.title(args.name)
		lc_plot = '%s/%shalo_lc_o%s.png' % (args.save_dir,args.name,args.order)
		plt.savefig(lc_plot)
		print('Saved halo-corrected light curve plot to %s' % lc_plot)

		# figure 2: log10 of the weight map scaled by its number of pixels
		plt.figure(2)
		plt.clf()
		# copy so that set_bad does not alter the shared matplotlib colormap
		cmap = copy.copy(mpl.cm.seismic)
		cmap.set_bad('k',1.)
		im = np.log10(weightmap.T*norm)
		plt.imshow(im,cmap=cmap, vmin=-2*np.nanmax(im),vmax=2*np.nanmax(im),
			interpolation='None',origin='lower')
		plt.colorbar()
		plt.title('TV-min Weightmap %s' % args.name)
		wm_plot = '%s/%s_weightmap_o%s_sub%s.png' % (args.save_dir,args.name,args.order,args.sub)
		plt.savefig(wm_plot)
		print('Weight map saved to %s' % wm_plot)

		# figure 3: log10 of the median flux per pixel
		plt.figure(3)
		plt.clf()
		cmap = copy.copy(mpl.cm.hot)
		cmap.set_bad('k',1.)
		im = np.log10(np.nanmedian(tpf,axis=0))
		plt.imshow(im,cmap=cmap, vmin=np.nanmin(im),vmax=np.nanmax(im),
			interpolation='None',origin='lower')
		plt.colorbar()
		plt.title('%s Flux Map' % args.name)
		flux_plot = '%s/%s_fluxmap_o%s_sub%s.png' % (args.save_dir,args.name,args.order,args.sub)
		plt.savefig(flux_plot)
		print('Flux map saved to %s' % flux_plot)


if __name__ == '__main__':
	main()
