#!python

from bruce import lc
import emcee, corner, sys, os, numpy as np
import matplotlib.pyplot as plt 
import argparse 
from multiprocessing import Pool
from celerite.modeling import Model
from celerite import terms, GP
from scipy.stats import chisquare, sem
import matplotlib.cm as cm
# Silence all Python warnings for this script. ``np.warnings`` was merely an
# alias of the standard-library ``warnings`` module and is not available in
# current NumPy releases, so use ``warnings`` directly.
import warnings
warnings.filterwarnings('ignore')

# Colour used for model curves and GP uncertainty bands in the plots.
color = "#ff7f0e"

def lc_bin(time, flux, bin_width):
    '''
    Bin a light curve into contiguous bins of a given width.

    ``time`` and ``bin_width`` must have the same units. Bins start at
    min(time) and extend far enough that the last bin contains max(time), so
    no points are dropped off the end. Bins holding fewer than two points (or
    whose standard error is NaN, e.g. because they contain NaN values) are
    dropped, since the standard error of the mean is undefined for them.

    Parameters
    ----------
    time : array_like
        Observation times.
    flux : array_like
        Values averaged in each bin; same length as ``time``.
    bin_width : float
        Width of each bin in the units of ``time``; must be positive.

    Returns
    -------
    time_bin, flux_bin, err_bin : ndarray
        Bin centres, the mean of ``flux`` in each bin, and the standard error
        of that mean (``scipy.stats.sem``).

    Raises
    ------
    ValueError
        If ``bin_width`` is not positive.
    '''
    if not bin_width > 0:
        raise ValueError('bin_width must be positive, got {:}'.format(bin_width))

    time = np.asarray(time, dtype=float)
    flux = np.asarray(flux, dtype=float)
    if time.size == 0:
        return np.array([]), np.array([]), np.array([])

    t_min, t_max = np.min(time), np.max(time)
    # The upper edge of the final bin must lie strictly above t_max; otherwise
    # np.digitize places the latest points past the last edge and they are
    # silently lost.
    n_bins = int(np.floor((t_max - t_min) / bin_width)) + 1
    edges = t_min + bin_width * np.arange(n_bins + 1)
    if edges[-1] <= t_max:  # guard against floating-point rounding
        edges = np.append(edges, edges[-1] + bin_width)
        n_bins += 1

    dig = np.digitize(time, edges)  # bin i (1-based): edges[i-1] <= t < edges[i]
    centres = (edges[1:] + edges[:-1]) / 2

    flux_binned = np.full(n_bins, np.nan)
    err_binned = np.full(n_bins, np.nan)
    for i in range(n_bins):
        in_bin = flux[dig == i + 1]
        if in_bin.size >= 2:
            flux_binned[i] = in_bin.mean()
            err_binned[i] = sem(in_bin)

    keep = ~np.isnan(err_binned)
    return centres[keep], flux_binned[keep], err_binned[keep]


# Phase
def phaser(time, t_zero, period):
    '''
    Return the orbital phase of ``time`` for the ephemeris (t_zero, period).

    The phase is the fractional part of (time - t_zero) / period, so it lies
    in [0, 1) up to floating-point rounding. Works on scalars and arrays.
    '''
    cycles = (time - t_zero) / period
    return cycles - np.floor(cycles)


# Define the model
# Parameter bounds handed to the celerite models (transitmodel / Matern32Term).
# Keys must match the models' parameter names exactly: celerite looks bounds
# up by parameter name, so a misspelt key (e.g. 'light3' for 'light_3') would
# be silently ignored and leave that parameter unbounded.
transit_model_bounds = dict(radius_1 = (0, 0.9), k = (0, 0.9), b = (0,2), fs = (-1,1), fc = (-1,1), SBR = (0,None), light_3 = (0, None), J = (0, None))
kernel_bounds = dict(log_sigma = (-10,10), log_rho = (-10,15))

class transitmodel(Model):
    '''
    celerite mean model giving the magnitude light curve computed by ``lc``
    (from the ``bruce`` package).

    Parameter values are set through the constructor keywords or
    ``set_parameter`` and read back as attributes. The inclination passed to
    ``lc`` is derived from the impact parameter as
    incl = arccos(radius_1 * b), in degrees. ``J`` is a jitter term added in
    quadrature to the magnitude errors in :meth:`log_likelihood`; it is not
    used by :meth:`get_value`.
    '''
    parameter_names = ("t_zero", "period", "radius_1", "k", "fs", "fc", "b", "q", "albedo", "alpha_doppler",
                        "K1", "spots", "omega_1", "ldc_law_1","ldc_1_1", "ldc_1_2", "gdc_1",
                        "SBR", "light_3", "E_tol" , "zp", "J" )

    def get_value(self, t):
        '''
        Return the model magnitude at times ``t``:
        zp - 2.5*log10(lc(t, ...)).

        NOTE(review): assumes ``lc`` returns flux normalised so that the
        out-of-eclipse level maps onto ``zp`` - confirm against bruce.
        '''
        return self.zp - 2.5*np.log10(lc(t, t_zero = self.t_zero, period = self.period,
                radius_1 = self.radius_1, k=self.k, 
                fs = self.fs, fc = self.fc, 
                q=self.q, albedo = self.albedo,
                alpha_doppler=self.alpha_doppler, K1 = self.K1,
                spots = np.array(self.spots), omega_1=self.omega_1,
                incl = 180*np.arccos(self.radius_1*self.b)/np.pi,
                ld_law_1=int(self.ldc_law_1), ldc_1_1 = self.ldc_1_1, ldc_1_2 = self.ldc_1_2, gdc_1 = self.gdc_1,
                SBR=self.SBR, light_3 = self.light_3,
                E_tol=self.E_tol))

    def log_likelihood(self, t, mag, mag_err):
        '''
        Gaussian log-likelihood (up to an additive constant) of ``mag`` given
        the model at ``t``, with per-point variance mag_err**2 + J**2.

        Returns -inf instead of NaN/inf when the result is not finite (for
        example radius_1*b > 1 makes the inclination arccos undefined), so a
        sampler treats the proposal as rejected rather than failing.
        '''
        model = self.get_value(t)
        wt = 1.0 / (mag_err**2 + self.J**2)
        loglike = -0.5*np.sum((mag - model)**2*wt - np.log(wt))
        if not np.isfinite(loglike):
            return -np.inf
        return loglike


def lnlike(theta, time, mag, mag_err, t_zero_ref, period_ref, theta_names ):
    '''
    emcee log-probability for the transit-only (no GP) fit.

    Writes ``theta`` into the module-level ``transit_model`` (created in the
    ``__main__`` section) under the names in ``theta_names``, applies hard
    limits (a flat prior), and returns the model log-likelihood.

    Parameters
    ----------
    theta : sequence of float
        Trial parameter values, in the order of ``theta_names``.
    time, mag, mag_err : ndarray
        The light curve being fitted.
    t_zero_ref, period_ref : float
        Reference ephemeris: t_zero must stay within 0.2 periods of
        ``t_zero_ref`` and period within 1e-3 of ``period_ref``.
    theta_names : sequence of str
        ``transitmodel`` parameter names matching ``theta``.

    Returns
    -------
    float
        The log-likelihood, or -inf when a limit is violated or the
        likelihood is not finite (emcee refuses NaN log-probabilities).

    NOTE: ``transit_model`` is a module global, so worker processes of a
    multiprocessing Pool only see it if they inherit this process's globals
    (the 'fork' start method).
    '''
    # First, set the attributes
    for name, value in zip(theta_names, theta):
        transit_model.set_parameter(name, value)

    # Then check the hard limits
    if (transit_model.t_zero < t_zero_ref - 0.2*period_ref) or (transit_model.t_zero > t_zero_ref + 0.2*period_ref) : return -np.inf
    if (transit_model.period < period_ref - 1e-3) or (transit_model.period > period_ref + 1e-3) : return -np.inf
    if (transit_model.k < 0.0) or (transit_model.k > 0.8) : return -np.inf
    if (transit_model.radius_1 < 0.0) or (transit_model.radius_1 > 0.8) : return -np.inf
    if (transit_model.b < 0) or (transit_model.b > 1.0 + transit_model.k) : return -np.inf
    if (transit_model.J < 0) : return -np.inf
    if (transit_model.q < 0) : return -np.inf
    if (transit_model.zp < -20) or (transit_model.zp > 20) : return -np.inf
    if (transit_model.SBR < 0) or (transit_model.SBR > 1) : return -np.inf

    # Now return the log-likelihood, mapping NaN/inf (an invalid model) to -inf
    loglike = transit_model.log_likelihood(time, mag, mag_err)
    if not np.isfinite(loglike):
        return -np.inf
    return loglike

def lnlike_gp(theta, time, mag, mag_err, t_zero_ref, period_ref, theta_names ):
    '''
    emcee log-probability for the transit + GP fit.

    Writes ``theta`` into the module-level celerite ``gp`` (whose mean model
    is ``transit_model``; both are created in the ``__main__`` section) using
    the prefixed names in ``theta_names`` ('mean:...' / 'kernel:...'), checks
    the ephemeris limits, and returns log-likelihood + log-prior.

    ``time`` and ``mag_err`` are not used here: ``gp.compute`` was already
    called on them in ``__main__``. They are accepted so the signature matches
    ``lnlike``.

    Returns
    -------
    float
        log-likelihood + log-prior, or -inf when a limit or bound is violated,
        the covariance cannot be factorised, or the result is not finite.

    NOTE: ``gp`` is a module global, so worker processes of a
    multiprocessing Pool only see it if they inherit this process's globals
    (the 'fork' start method).
    '''
    # First, set the attributes
    for name, value in zip(theta_names, theta):
        gp.set_parameter(name, value)

    t_zero = gp.get_parameter('mean:t_zero')
    period = gp.get_parameter('mean:period')
    if (t_zero < t_zero_ref - 0.2*period_ref) or (t_zero > t_zero_ref + 0.2*period_ref) : return -np.inf
    if (period < period_ref - 1e-3) or (period > period_ref + 1e-3) : return -np.inf

    # The remaining parameter bounds are enforced by celerite's log prior
    lp = gp.log_prior()
    if not np.isfinite(lp) : return -np.inf

    # quiet=True makes celerite return -inf instead of raising when the
    # covariance matrix is not positive definite.
    loglike = gp.log_likelihood(mag, quiet=True)
    if not np.isfinite(loglike) : return -np.inf
    return loglike + lp

# Welcome messages
welcome_message = '''---------------------------------------------------
-                   NGTSfit V.2                   -
-             samuel.gill@warwick.ac.uk           -
---------------------------------------------------'''

description = '''A program to fit binary star observations elegantly. 
Use the -h flag to see all available options for the fit. For any questions, 
please email samuel.gill@warwick.ac.uk'''

emcee_message = '''---------------------------------------------------
-                   emcee                         -
---------------------------------------------------'''


# Argument parser. Help strings state the real defaults of each option.
parser = argparse.ArgumentParser('ngtsfit', description=description)

parser.add_argument("filename",
                    help='The light-curve file, with time, mag and mag_err in its first three columns.')

parser.add_argument('-a', 
                    '--t_zero',
                    help='The transit epoch in arbitrary time units consistent with the input file [default 0.].',
                    default=0.0, type=float)

parser.add_argument('-b', 
                    '--period',
                    help='The orbital period in arbitrary time units consistent with the input file [default 1.].',
                    default=1.0, type=float)

parser.add_argument('-c', 
                    '--radius_1',
                    help='The radius of star 1 in units of the semi-major axis, a [default 0.2].',
                    default=0.2, type=float)

parser.add_argument('-d', 
                    '--k',
                    help='The ratio of the radii of star 2 and star 1 (R2/R1) [default 0.2].',
                    default=0.2, type=float)

parser.add_argument('-e', 
                    '--b',
                    help='The impact parameter of the orbit (incl = arccos(radius_1*b)) [default 0.].',
                    default=0., type=float)

parser.add_argument('-f', 
                    '--zp',
                    help='The photometric zero-point [default 0.].',
                    default=0., type=float)

parser.add_argument('-g', 
                    '--limb_darkening_law',
                    help='The limb-darkening law for star 1. Options are: uniform, quadratic, power2 [default power2].',
                    choices=['uniform', 'quadratic', 'power2'],
                    default='power2')

parser.add_argument('-i', 
                    '--ldc_1',
                    help='The first limb-darkening coefficient [default 0.8].',
                    default=0.8, type=float)

parser.add_argument('-j', 
                    '--ldc_2',
                    help='The second limb-darkening coefficient [default 0.8].',
                    default=0.8, type=float)

parser.add_argument('-k', '--spots', 
                    help=('Spot parameters for star 1, if required: four values per spot '
                          '(longitude and latitude of the spot centre [radians], '
                          'angular radius [radians], contrast ratio Is/Ip).'),
                    nargs='+', 
                    type=float, 
                    default=[])

parser.add_argument('-l', 
                    '--gdc_1',
                    help='The gravity darkening coefficient of star 1 [default 0.4].',
                    default=0.4, type=float)

parser.add_argument('-m', 
                    '--q',
                    help='The mass ratio of star 2 to star 1 [default 0.].',
                    default=0., type=float)

parser.add_argument('-n', 
                    '--albedo',
                    help='The albedo of the secondary [default 0.]',
                    default=0., type=float)

parser.add_argument('-o', 
                    '--alpha_doppler',
                    help='The alpha_doppler parameter [default 0.].',
                    default=0., type=float)

parser.add_argument('-p', 
                    '--K1',
                    help='The semi-amplitude [km/s] of radial velocity (used for ellipsoidal variation and rv) [default 10].',
                    default=10, type=float)

parser.add_argument('-q', 
                    '--light_3',
                    help='The third light in the system [default 0.].',
                    default=0., type=float)

parser.add_argument('-r', 
                    '--SBR',
                    help='The surface-brightness ratio [default 0.].',
                    default=0., type=float)

parser.add_argument('-w', 
                    '--J',
                    help='The additional jitter, added in quadrature to the magnitude errors [default 0.].',
                    default=0., type=float)

parser.add_argument('--trial', action="store_true", default=False,
                    help='Plot the starting model against the data.')

parser.add_argument('-s', 
                    '--plot_alpha',
                    help='The alpha (transparency) of the plotted data points [default 1.].',
                    default=1., type=float)

parser.add_argument('-t', '--fitpars', 
                    help='A space-separated list of free parameters', 
                    nargs='+',
                    default=[])

parser.add_argument('--emcee', action="store_true", default=False,
                    help='Sample the free parameters (--fitpars) with emcee.')

parser.add_argument('-u', 
                    '--emcee_steps',
                    help='The number of emcee steps [default 10000]',
                    default=10000, type=int)

parser.add_argument('-v', 
                    '--emcee_burn_in',
                    help='The number of emcee steps to discard [default 5000]',
                    default=5000, type=int)

parser.add_argument('-x', 
                    '--threads',
                    help='The number of threads to use [default 10]',
                    default=10, type=int)

parser.add_argument('-y', 
                    '--bin',
                    help='The bin width from which to bin the lightcurve, in minutes [default 0, no binning].', 
                    default=0.0, type=float)

parser.add_argument('--gp', action="store_true", default=False,
                    help='Model correlated noise with a celerite Matern-3/2 Gaussian process.')

parser.add_argument('-ab', 
                    '--log_sigma',
                    help='log sigma of the Matern-3/2 GP kernel [default 0.]', 
                    default=0.0, type=float)

parser.add_argument('-ac', 
                    '--log_rho',
                    help='log rho of the Matern-3/2 GP kernel [default 0.]', 
                    default=0.0, type=float)

parser.add_argument('-ad', 
                    '--fs',
                    help='fs for eccentricity = sin(omega)*root(e) [default 0.]', 
                    default=0.0, type=float)

parser.add_argument('-ae', 
                    '--fc',
                    help='fc for eccentricity = cos(omega)*root(e) [default 0.]', 
                    default=0.0, type=float)

parser.add_argument('-af', 
                    '--omega_1',
                    help='Ratio of angular rotation of the host star to orbiting body [default 1.].', 
                    default=1.0, type=float)

# Integer limb-darkening law codes passed to lc() (as ld_law_1) for each
# --limb_darkening_law option.
LD_LAWS = {'uniform': 0, 'quadratic': 1, 'power2': 2}


def _load_lightcurve(filename):
    '''
    Load time, mag and mag_err from the first three columns of ``filename``.

    Extra columns are ignored, and rows where any of the three values is NaN
    or infinite are dropped. The file is read only once.

    Returns
    -------
    time, mag, mag_err : ndarray
    '''
    time, mag, mag_err = np.loadtxt(filename, usecols=(0, 1, 2), unpack=True, ndmin=2)
    good = np.isfinite(time) & np.isfinite(mag) & np.isfinite(mag_err)
    return time[good], mag[good], mag_err[good]


def _evaluate_in_phase_units(model, phases, zero_zp=False):
    '''
    Evaluate ``model`` at each array in ``phases``, interpreted as orbital phase.

    The model's t_zero and period are temporarily set to 0 and 1 (and zp to 0
    when ``zero_zp`` is true); the original values are restored before
    returning, even if evaluation fails.

    Returns
    -------
    list of ndarray
        ``model.get_value(p)`` for each ``p`` in ``phases``.
    '''
    saved = {name: model.get_parameter(name) for name in ('t_zero', 'period', 'zp')}
    try:
        model.set_parameter('t_zero', 0.0)
        model.set_parameter('period', 1.0)
        if zero_zp:
            model.set_parameter('zp', 0.0)
        return [model.get_value(p) for p in phases]
    finally:
        for name, value in saved.items():
            model.set_parameter(name, value)


def _plot_phase_fold(model, time, mag, mag_err, plot_alpha):
    '''
    Plot the data folded on the model's current t_zero/period, with the model
    curve overlaid.

    The title's -2 lnL / N is computed before the model is evaluated in phase
    units, because the likelihood must use the real observation times and
    ephemeris.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    phase : ndarray
        Phase of each data point.
    phase_time, model_phase : ndarray
        Phase grid over [-0.2, 0.8] and the model evaluated on it.
    '''
    loglike = model.log_likelihood(time, mag, mag_err)
    phase = phaser(time, model.t_zero, model.period)
    phase_time = np.linspace(-0.2, 0.8, 10000)
    model_phase, = _evaluate_in_phase_units(model, [phase_time])

    fig, ax = plt.subplots(figsize=(15, 5))
    ax.scatter(phase, mag, c='k', s=10, alpha=plot_alpha)
    ax.scatter(phase - 1, mag, c='k', s=10, alpha=plot_alpha)
    ax.plot(phase_time, model_phase, color=color)
    ax.invert_yaxis()
    ax.set_xlim(-0.2, 0.8)
    ax.set_xlabel('Phase')
    ax.set_ylabel('Mag')
    ax.set_title('$\\chi^2_r$ : {:.6}'.format(-2*loglike/len(time)))
    fig.tight_layout()
    return fig, ax, phase, phase_time, model_phase


def _plot_gp_timeseries(ax, time, mag, mu, std, loglike, plot_alpha):
    '''
    Draw the data and the GP prediction band (mu +/- std) against time on
    ``ax``, titled with -2 lnL / N for the supplied log-likelihood.
    '''
    ax.scatter(time, mag, c='k', s=10, alpha=plot_alpha)
    ax.fill_between(time, mu+std, mu-std, color=color, alpha=0.3, edgecolor="none")
    ax.invert_yaxis()
    ax.set_ylabel('Mag')
    ax.set_xlabel('Time [d]')
    ax.set_title('$\\chi^2_r$ : {:.6}'.format(-2*loglike/len(time)))


if __name__ == "__main__":
    # argparse exits with a usage error itself when no filename is given.
    args = parser.parse_args()

    # Print the welcome message
    print(welcome_message)

    # Load the datafile (non-finite rows removed)
    time, mag, mag_err = _load_lightcurve(args.filename)
    print('Loaded {:,} lines from {:}'.format(len(time),args.filename))
    if args.bin > 0 :
        time, mag, mag_err = lc_bin(time, mag, args.bin/24./60.)
        print('\treduced to {:} lines with {:}-minute binning'.format(len(time), args.bin))

    time = time.astype(np.float64)
    mag = mag.astype(np.float64)
    mag_err = mag_err.astype(np.float64)

    print('---------------------------------------------------')

    # Report
    print('System parameters:')
    print('\tt_zero   : {:}'.format(args.t_zero))
    print('\tperiod   : {:}'.format(args.period))
    print('\tradius_1 : {:}'.format(args.radius_1))
    print('\tk        : {:}'.format(args.k))
    print('\tb        : {:} [{:.2f} deg]'.format(args.b, 180*np.arccos(args.radius_1*args.b)/np.pi))
    print('\tzp       : {:}'.format(args.zp))
    print('\tld_law   : {:}'.format(args.limb_darkening_law))
    print('\t\t   -------')
    print('\t\t   ldc_1 {:}'.format(args.ldc_1))
    print('\t\t   ldc_2 {:}'.format(args.ldc_2))
    print('\t\t   gdc_1 {:}'.format(args.gdc_1))
    nspots = len(args.spots)//4
    print('\tspots    : {:}'.format(nspots))
    for i in range(nspots):
        longitude, latitude, spot_radius, contrast = args.spots[4*i:4*i + 4]
        print('\t\t   Spot ', i, '\n\t\t   -------')
        print('\t\t   longitude of spot centre (radians) = {:}'.format(longitude))
        print('\t\t   latitude of spot centre (radians)  = {:}'.format(latitude))
        print('\t\t   angular radius of spot (radians)   = {:}'.format(spot_radius))
        print('\t\t   Spot contrast ratio (a=Is/Ip).     = {:}'.format(contrast))
    print('\tq      : {:}'.format(args.q))
    print('\talbedo : {:}'.format(args.albedo))
    print('\talpha  : {:}'.format(args.alpha_doppler))
    print('\tK1     : {:}'.format(args.K1))
    print('\tfs     : {:}'.format(args.fs))
    print('\tfc    : {:}'.format(args.fc))
    print('\tomega_1    : {:}'.format(args.omega_1))
    print('\tlight_3: {:}'.format(args.light_3))
    print('\tsbr    : {:}'.format(args.SBR))
    print('\ttrial  : {:}'.format(args.trial))
    print('\tFree parameters ({:}):'.format(len(args.fitpars)))
    for name in args.fitpars:
        print('\t\t{:}'.format(name))
    print('\tGP : {:}'.format(args.gp))
    print('\tlogsigma : {:}'.format(args.log_sigma))
    print('\tlog_rho : {:}'.format(args.log_rho))
    print('\tThreads  : {:}'.format(args.threads))

    if args.limb_darkening_law not in LD_LAWS:
        raise ValueError('Unknown limb-darkening law "{:}"; options are: {:}'.format(
            args.limb_darkening_law, ', '.join(LD_LAWS)))
    ld_law = LD_LAWS[args.limb_darkening_law]

    # Initialise the transit model. lnlike/lnlike_gp read `transit_model` and
    # `gp` as module globals, so they must be created at module level here.
    transit_model = transitmodel(t_zero = args.t_zero, period = args.period, radius_1 = args.radius_1, k=args.k,
                                fs=args.fs,fc = args.fc, b = args.b, q = args.q, albedo = args.albedo, alpha_doppler = args.alpha_doppler,
                                K1 = args.K1, spots=  np.array(args.spots), omega_1=args.omega_1, ldc_law_1 = ld_law,
                                ldc_1_1 = args.ldc_1, ldc_1_2 = args.ldc_2, gdc_1 = args.gdc_1, SBR  = args.SBR,
                                light_3 = args.light_3,
                                E_tol = 1e-4, zp = args.zp, bounds=transit_model_bounds, J=args.J)

    if args.gp :
        kernel = terms.Matern32Term(log_rho = args.log_rho, log_sigma = args.log_sigma, bounds = kernel_bounds)
        gp = GP(kernel, mean=transit_model, fit_mean=True)
        gp.compute(time, mag_err)

    if args.trial:
        if args.gp :
            mu, var = gp.predict(mag, time, return_var=True)
            std = np.sqrt(var)

            # Top panel: the data and the GP prediction against time
            fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(15,10))
            _plot_gp_timeseries(ax1, time, mag, mu, std,
                                gp.log_likelihood(mag) + gp.log_prior(), args.plot_alpha)

            # Bottom panel: subtract the GP prediction, phase-fold, and add the
            # transit model (with zp = 0) back on so the eclipse is visible.
            phase = phaser(time, transit_model.t_zero, transit_model.period)
            phase_time = np.linspace(-0.2, 0.8, 10000)
            model_data, model_phase = _evaluate_in_phase_units(
                transit_model, [phase, phase_time], zero_zp=True)
            folded = (mag - mu) + model_data
            ax2.scatter(phase, folded, c='k', s=10)
            ax2.scatter(phase-1, folded, c='k', s=10)
            band = np.median(std)
            ax2.fill_between(phase_time, model_phase+band, model_phase-band, color=color, alpha=0.3, edgecolor="none")
            ax2.set_xlim(-0.1,0.1)
            ax2.invert_yaxis()
            ax2.set_ylabel('Mag')
            ax2.set_xlabel('Phase')

            fig.tight_layout()
            plt.show()

        else:
            _plot_phase_fold(transit_model, time, mag, mag_err, args.plot_alpha)
            plt.show()

    if args.emcee:
        print(emcee_message)

        # Validate the free parameters: they must name transitmodel parameters
        # or, with --gp, the kernel parameters log_sigma / log_rho.
        kernel_names = ('log_sigma', 'log_rho') if args.gp else ()
        valid_names = tuple(transit_model.parameter_names) + kernel_names
        if len(args.fitpars) == 0:
            raise ValueError('No free parameters given; list them with --fitpars.')
        for name in args.fitpars:
            if name not in valid_names:
                raise ValueError('Parameter "{:}" is not a valid identifier. Valid names are: {:}'.format(
                    name, ', '.join(valid_names)))
        if args.emcee_burn_in >= args.emcee_steps:
            raise ValueError('--emcee_burn_in ({:}) must be smaller than --emcee_steps ({:}).'.format(
                args.emcee_burn_in, args.emcee_steps))

        # Starting values come from the models, which were initialised from
        # the command line.
        if args.gp:
            args.fitpars = [('kernel:' if name in kernel_names else 'mean:') + name for name in args.fitpars]
            theta = [float(gp.get_parameter(name)) for name in args.fitpars]
            log_prob = lnlike_gp
        else:
            theta = [float(transit_model.get_parameter(name)) for name in args.fitpars]
            log_prob = lnlike

        ndim = len(args.fitpars)
        nwalkers = 4*ndim
        p0 = np.random.normal(theta, 1e-5, size=(nwalkers, ndim))

        # Set up the backend
        # Don't forget to clear it in case the file already exists
        filename = "emcee_output.h5"
        backend = emcee.backends.HDFBackend(filename)
        backend.reset(nwalkers, ndim)

        # NOTE: lnlike/lnlike_gp use the module-level transit_model/gp, so the
        # Pool workers only see them if they inherit this process's globals
        # (the 'fork' start method).
        with Pool(int(args.threads)) as pool:
            sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, args = (time, mag, mag_err, args.t_zero, args.period, args.fitpars), backend=backend, pool=pool)
            sampler.run_mcmc(p0, args.emcee_steps, progress=True)

        # Chain plot; squeeze=False keeps `axes` 2-D even when ndim == 1
        fig_chain, axes = plt.subplots(ndim, figsize=(6, 3*ndim), squeeze=False)
        samples = sampler.get_chain()
        for i, ax in enumerate(axes[:, 0]):
            ax.semilogx(samples[:,:,i], 'k', alpha = 0.3)
            ax.set_xlim(0,len(samples))
            ax.set_ylabel(args.fitpars[i])
        fig_chain.tight_layout()
        fig_chain.savefig('chain.png')
        plt.close(fig_chain)

        samples = sampler.get_chain(flat=True, discard=args.emcee_burn_in)
        logs = sampler.get_log_prob(flat=True, discard=args.emcee_burn_in)

        best_idx = np.argmax(logs)
        best_step = samples[best_idx]
        low_err = best_step - np.percentile(samples, 16, axis=0)
        high_err = np.percentile(samples, 84, axis=0) - best_step

        print('Best result:')
        with open('results.txt', 'w') as output_file:
            for name, best, hi, lo in zip(args.fitpars, best_step, high_err, low_err):
                print('{:>15} = {:.5f} + {:.5f} - {:.5f}'.format(name, best, hi, lo))
                output_file.write('{:>15},{:.5f},{:.5f},{:.5f}\n'.format(name, best, hi, lo))

        # now make the corner
        fig_corner = corner.corner(samples, labels=args.fitpars, truths = best_step)
        fig_corner.savefig('corner.png')
        plt.close(fig_corner)

        # Now get the best model
        if args.gp:
            # First, set the parameters (the transit model is the GP's mean
            # model; it is also updated explicitly)
            for name, value in zip(args.fitpars, best_step):
                gp.set_parameter(name, value)
                if name.startswith('mean:'):
                    transit_model.set_parameter(name[len('mean:'):], value)

            # Predict and score with the best-fitting parameters, including
            # zp; zp is only zeroed below, for drawing the folded model.
            mu, var = gp.predict(mag, time, return_var=True)
            np.save('muvar', np.array([mu,var]))
            std = np.sqrt(var)

            fig, (ax1, ax2) = plt.subplots(nrows=2, ncols=1, figsize=(15,10))
            _plot_gp_timeseries(ax1, time, mag, mu, std,
                                gp.log_likelihood(mag) + gp.log_prior(), args.plot_alpha)

            # Then remove the continuum: interpolate the GP prediction across
            # the eclipse using only out-of-eclipse points, folding on the
            # best-fit ephemeris.
            phase = phaser(time, transit_model.t_zero, transit_model.period)
            search_phase = np.linspace(-0.4, 0.4, 10000)
            phase_time = np.linspace(-0.2, 0.8, 10000)
            search_model, model_phase = _evaluate_in_phase_units(
                transit_model, [search_phase, phase_time], zero_zp=True)
            # With zp = 0, points where the model magnitude is above 0 (fainter
            # than the zero-point) are treated as in eclipse.
            in_eclipse = search_phase[search_model > 0]
            phase_width = -np.min(in_eclipse) if in_eclipse.size else 0.0
            out_of_eclipse = (phase > phase_width) & (phase < (1-phase_width))
            residual = mag - np.interp(time, time[out_of_eclipse], mu[out_of_eclipse])

            ax2.scatter(phase, residual, c='k', s=10)
            ax2.scatter(phase-1, residual, c='k', s=10)
            ax2.plot(phase_time, model_phase, c=color)
            ax2.set_xlim(-0.2,0.8)
            ax2.invert_yaxis()
            ax2.set_ylabel('Mag')
            ax2.set_xlabel('Phase')

            fig.tight_layout()
            fig.savefig("best_model.png")
            plt.close(fig)

        else:
            # First, set the parameters
            for name, value in zip(args.fitpars, best_step):
                transit_model.set_parameter(name, value)

            fig, ax, phase, phase_time, model_phase = _plot_phase_fold(
                transit_model, time, mag, mag_err, args.plot_alpha)

            # The data and the model grid have different lengths, so store them
            # as a 1-D object array (load with allow_pickle=True).
            phasemodel = np.empty(4, dtype=object)
            for i, arr in enumerate((phase, mag, phase_time, model_phase)):
                phasemodel[i] = arr
            np.save('phasemodel', phasemodel)

            fig.savefig("best_model.png")
            ax.set_xlim(-0.05,0.05)
            fig.savefig("best_model_zoom.png")
            plt.close(fig)