#!/anaconda2/bin/python
import numpy as np
import matplotlib.pyplot as plt
from astropy.table import Table
import scipy.optimize as optimize
from astropy.io import fits
from time import time as clock
from os.path import join, exists, abspath, basename

from halophot.halo_tools import *

from argparse import ArgumentParser

import matplotlib as mpl

# The colour-blind seaborn style was renamed 'seaborn-v0_8-colorblind' in
# matplotlib 3.6. Prefer the new name when this installation provides it and
# otherwise fall back to the original one (which raises, as before, if it is
# not available either).
_style = 'seaborn-v0_8-colorblind'
if _style not in mpl.style.available:
    _style = 'seaborn-colorblind'
mpl.style.use(_style)

# To make sure we always have the same matplotlib settings
# (the values in the trailing comments are the ipython notebook settings).

mpl.rcParams['figure.figsize'] = (8.0, 6.0)   # (6.0,4.0)
mpl.rcParams['font.size'] = 18                # 10
mpl.rcParams['savefig.dpi'] = 200             # 72
mpl.rcParams['axes.labelsize'] = 16
mpl.rcParams['xtick.labelsize'] = 12
mpl.rcParams['ytick.labelsize'] = 12


'''-----------------------------------------------------------------
halo

This executable Python script allows you to detrend any single object using halo
photometry.

An example call is:

halo ktwo200007768-c04_lpd-targ.fits --data-dir /home/ben/Data/kepler/halo/ --name atlas -c 4 --do-plot
-----------------------------------------------------------------'''

if __name__ == '__main__':
    def _float_list(text):
        """Parse a comma-separated list such as '[550,2200]' into a float array."""
        return np.fromstring(text.strip('[]'), sep=',')

    # ------------------------------------------------------------------
    # Command-line interface. Every option keeps its original name, type
    # and default; --no-analytic is the only addition.
    # ------------------------------------------------------------------
    ap = ArgumentParser(description='halophot: K2 halo photometry with total variation.')
    ap.add_argument('fname', type=str, help='Input target pixel file name.')
    ap.add_argument('--data-dir', default='', type=str,
        help='Directory containing the input file')
    ap.add_argument('--name', default='test', type=str, help='Target name')
    ap.add_argument('-c', '--campaign', metavar='C', default=4, type=int,
        help='Campaign number')
    ap.add_argument('-o', '--order', metavar='O', type=int, default=1,
        help='TV Order: 1 for gradient, 2 for concavity')
    ap.add_argument('-sub', type=int, default=1, help='Subsampling parameter')
    ap.add_argument('-maxiter', type=int, default=151, help='Maximum # iterations')
    ap.add_argument('--splits', default=None, type=_float_list,
        help='List of time values for kernel splits')
    ap.add_argument('--rr', default=None, type=_float_list,
        help='rmin, rmax (pix)')
    ap.add_argument('--quiet', action='store_true', default=False,
        help='suppress messages')
    ap.add_argument('--save-dir', default='.',
        help='The directory to save the output file in')
    ap.add_argument('--do-plot', action='store_true', default=False,
        help='produce plots')
    # The help text used to be a copy of the --do-plot one.
    ap.add_argument('--do-split', action='store_true', default=False,
        help='detrend each segment between the split times separately and stitch the results')
    ap.add_argument('--random-init', action='store_true', default=False,
        help='initialize search with random seed')
    ap.add_argument('--minflux', type=float, default=100., help='Minimum flux to include')
    ap.add_argument('--thresh', type=float, default=0.8,
        help='What fraction of saturation to throw away')
    ap.add_argument('--consensus', action='store_true', default=False,
        help='use with subsampling to run fast and avoid overfitting')
    # --analytic is a store_true flag that already defaults to True, so on its
    # own it can never be turned off; it is kept for backward compatibility
    # and --no-analytic provides the missing off switch.
    ap.add_argument('--analytic', action='store_true', default=True,
        help='use analytic derivatives; orders of magnitude faster (default)')
    ap.add_argument('--no-analytic', dest='analytic', action='store_false',
        help='do not use analytic derivatives')
    ap.add_argument('--sigclip', action='store_true', default=False,
        help='sigma-clip the final light curve')
    ap.add_argument('--deathstar', action='store_true', default=False,
        help='remove background star pixels')

    args = ap.parse_args()

    # Default split values per campaign number (0-15); only campaign 4
    # defines any.
    csplits = {j: None for j in range(16)}
    csplits[4] = [550, 2200]

    # Values given with --splits take precedence over the campaign defaults;
    # campaigns without an entry yield None.
    if args.splits is not None:
        splits = args.splits
    else:
        splits = csplits.get(args.campaign)

    # Stop right away when the output directory is missing: previously this
    # only printed a message, so the whole computation still ran and then
    # failed when writing the results.
    if not exists(args.save_dir):
        raise SystemExit("Error: the save directory {:s} doesn't exist".format(args.save_dir))

    ### first load your data
    # join() inserts the path separator when --data-dir has no trailing slash
    # and returns fname unchanged when --data-dir is empty; plain string
    # concatenation produced a wrong path in the former case.
    fname = join(args.data_dir, args.fname)
    tpf, ts = read_tpf(fname)

    if args.campaign == 13:
        # Drop the cadences inside three time windows. The thresholds are in
        # increasing order, so the alternating or/and chain below keeps
        # exactly the times outside
        #   [2988.2553329814764, 2988.494],
        #   [3064.016165412577,  3064.75] and
        #   [3065.9776255118923, 3066.2225].
        # An earlier version made the same kind of cut on ts['cadence']
        # (140911-140922, 144619-144654, 144715-144726), and a fourth window
        # (3001.9834-3001.9849) is currently disabled.
        m1 = np.logical_or(ts['time']<2988.2553329814764,ts['time']>2988.494)
        m2 = np.logical_and(m1,ts['time']<3064.016165412577)
        m3 = np.logical_or(m2,ts['time']>3064.75)
        m4 = np.logical_and(m3,ts['time']<3065.9776255118923)
        m = np.logical_or(m4,ts['time']>3066.2225)
        tpf,ts = tpf[m,:,:], ts[m]

    # Campaigns 10 and 7: keep only the cadences after a fixed start time.
    start_cuts = {10: 2760, 7: 2470}
    if args.campaign in start_cuts:
        m = ts['time']>start_cuts[args.campaign]
        tpf, ts = tpf[m,:,:], ts[m]

    print('Data loaded!')

    start = clock()

    # get annulus if necessary
    if args.rr is not None:
        rmin, rmax = args.rr
        print('Getting annulus from',rmin,'to',rmax)
        tpf = get_annulus(tpf,rmin,rmax)
        print('Using',np.sum(np.isfinite(tpf[0,:,:])),'pixels')

    # destroy background stars
    if args.deathstar:
        print('Removing background stars')
        tpf = remove_stars(tpf)

    # Keyword options shared by every do_lc call. Building them once keeps
    # the calls consistent: the middle-segment call used to leave out
    # `analytic`.
    lc_kwargs = dict(maxiter=args.maxiter, random_init=args.random_init,
        thresh=args.thresh, minflux=args.minflux, consensus=args.consensus,
        analytic=args.analytic, sigclip=args.sigclip)

    if args.do_split:
        # Without split times the segment bounds below cannot be built
        # (this used to fail with a TypeError on splits[0]).
        if splits is None or len(splits) == 0:
            raise SystemExit('Error: --do-split needs split times; pass --splits '
                'or choose a campaign that has defaults')

        print('First doing one run to establish weights')
        tpf, newts, weights, weightmap, pixelvector = do_lc(tpf, ts, (None,None),
            args.sub, args.order, **lc_kwargs)

        print('Splitting at',splits)

        # Segment bounds: (None, s0), (s0, s1), ..., (s_last, None). Each
        # segment is started from the weights of the full run. As before,
        # weightmap and pixelvector end up holding the values of the last
        # segment, while weights stays that of the full run.
        edges = [None] + list(splits) + [None]
        segments = []
        for lower, upper in zip(edges[:-1], edges[1:]):
            _, seg_ts, _, weightmap, pixelvector = do_lc(tpf, ts, (lower,upper),
                args.sub, args.order, w_init=weights, **lc_kwargs)
            segments.append(seg_ts)

        ## now stitch these
        newts = stitch(segments)
    else:
        print('Not splitting')
        tpf, newts, weights, weightmap, pixelvector = do_lc(tpf, ts, (None,None),
            args.sub, args.order, **lc_kwargs)

    print_time(clock()-start)

    # Corrected flux of the (possibly stitched) light curve.
    opt_lc = newts['corr_flux'][:]

    if args.sigclip:
        good = sigma_clip(opt_lc,max_sigma=3.5)

        print('Clipping %d bad points' % np.sum(~good))
        if np.shape(pixelvector)[1] == np.size(good):
            # Same length as the light curve: clip the pixel vectors too and
            # rebuild the light curve from the weights, as before.
            pixelvector, newts = pixelvector[:,good], newts[good]
            opt_lc = np.dot(weights,pixelvector)
        else:
            # With --do-split, pixelvector comes from the last segment only
            # while `good` covers the stitched light curve, so the two cannot
            # be indexed together; clip the stitched light curve directly.
            # NOTE(review): confirm this is the intended light curve here.
            newts = newts[good]
            opt_lc = opt_lc[good]

    # Total variation of the median-normalised light curve, per data point.
    tv1 = diff_1(opt_lc/np.nanmedian(opt_lc))/float(np.size(opt_lc))
    tv2 = diff_2(opt_lc/np.nanmedian(opt_lc))/float(np.size(opt_lc))

    print('Total variation per point (first order): %f ' % tv1)

    print('Total variation per point (second order): %f' % tv2)

    ### save your new light curve!

    # Number of elements in the weight map; the weight-map plot multiplies
    # the weights by this value.
    norm = np.size(weightmap)

    name_parts = (args.save_dir,args.name,args.order)

    # Primary HDU: the transposed weight map, written as a plain array (the
    # masked version is only created after the file has been saved).
    primary = fits.PrimaryHDU(weightmap.T)

    # Table HDU: one double-precision column per key of the light curve.
    table_columns = []
    for key in newts.keys():
        table_columns.append(fits.Column(name=key,format="D",array=newts[key]))
    lc_table = fits.BinTableHDU.from_columns(table_columns)

    fits.HDUList([primary, lc_table]).writeto('%s/%shalo_lc_o%s.fits' % name_parts,overwrite=True)

    print('Saved halo-corrected light curve to %s/%shalo_lc_o%s.fits' % name_parts)

    # From here on NaN entries of the weight map are masked.
    weightmap = np.ma.array(weightmap,mask=np.isnan(weightmap))

    if args.do_plot:
        # Local import: only this plotting branch needs it.
        from copy import copy

        # Output paths, built once and reused for saving and reporting.
        lc_png = '%s/%shalo_lc_o%s.png' % (args.save_dir,args.name,args.order)
        weightmap_png = '%s/%s_weightmap_o%s_sub%s.png' % (args.save_dir,args.name,args.order,args.sub)
        fluxmap_png = '%s/%s_fluxmap_o%s_sub%s.png' % (args.save_dir,args.name,args.order,args.sub)

        # Figure 1: light curve relative to its median, positive fluxes only.
        m = (opt_lc>0.)
        plt.figure(1)
        plt.clf()
        plt.plot(newts['time'][m],opt_lc[m]/np.nanmedian(opt_lc[m]),'-')
        plt.xlabel('Time')
        plt.ylabel('Relative Flux')
        plt.title(args.name)
        plt.savefig(lc_png)
        print('Saved halo-corrected light curve plot to %s' % lc_png)

        # Figure 2: log10 of the weight map scaled by the pixel count, on a
        # colour scale symmetric about zero.
        plt.figure(2)
        plt.clf()
        # Work on a copy so that set_bad does not alter the colormap object
        # shared through mpl.cm (modifying it in place is deprecated in
        # recent matplotlib releases).
        cmap = copy(mpl.cm.seismic)
        cmap.set_bad('k',1.)
        im = np.log10(weightmap.T*norm)
        peak = np.nanmax(im)
        plt.imshow(im,cmap=cmap, vmin=-2*peak,vmax=2*peak,
            interpolation='None',origin='lower')
        plt.colorbar()
        plt.title('TV-min Weightmap %s' % args.name)
        plt.savefig(weightmap_png)
        print('Weight map saved to %s' % weightmap_png)

        # Figure 3: log10 of the flux summed over the first axis of the TPF.
        plt.figure(3)
        plt.clf()
        cmap = copy(mpl.cm.hot)
        cmap.set_bad('k',1.)
        im = np.log10(np.nansum(tpf,axis=0))
        plt.imshow(im,cmap=cmap, vmax=np.nanmax(im),
            interpolation='None',origin='lower')
        plt.colorbar()
        plt.title('%s Flux Map' % args.name)
        plt.savefig(fluxmap_png)
        print('Flux map saved to %s' % fluxmap_png)
