#!python

import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import lombscargle
from scipy.optimize import leastsq
from k2sc.ls import fasper
import emcee, corner
from astropy.stats import LombScargle
import sys, os, glob
import matplotlib.ticker as mticker

def phaser(t, t_zero, period):
    """Fold time(s) `t` onto the interval [0, 1) for a given period.

    Parameters
    ----------
    t : float or array-like
        Time stamp(s) to fold.
    t_zero : float
        Reference epoch that maps to phase 0.
    period : float
        Folding period, in the same units as `t`.

    Returns
    -------
    float or ndarray
        Fractional part of the number of cycles elapsed since `t_zero`.
    """
    # Number of (possibly fractional, possibly negative) cycles since t_zero.
    cycles = (t - t_zero)/period
    # Subtracting the floor keeps negative cycle counts in [0, 1) as well.
    return cycles - np.floor(cycles)

def psearch(time, flux):
    """Run a periodogram search and estimate false-alarm probabilities.

    The periodogram is computed with `fasper` and then restricted to
    periods strictly between 0.05 and 10 (in the units of `time`).

    Parameters
    ----------
    time : array-like
        Time stamps of the observations.
    flux : array-like
        Measurements at each time stamp.

    Returns
    -------
    period : ndarray
        Trial periods inside the search window.
    power : ndarray
        Periodogram power at each retained period.
    fap : ndarray
        False-alarm probability estimate at each retained period.
    j : int
        Index (into the returned arrays) of the highest-power period.

    Raises
    ------
    ValueError
        If no trial period falls inside the search window.
    """
    # Oversampling and high-frequency factors handed to fasper.  The
    # oversampling factor is named once because the effective number of
    # independent frequencies below must be computed with the *same* value.
    ofac, hifac = 60, 0.5

    # NOTE(review): fasper's return layout is taken from the names used at
    # this call site (frequencies, powers, number of output frequencies, ...).
    freq, power, nout, _, _ = fasper(time, flux, ofac, hifac)

    # Keep only the periods inside the search window.
    period = 1./freq
    m = (period > 0.05) & (period < 10)
    period, power = period[m], power[m]
    if period.size == 0:
        raise ValueError('no trial periods between 0.05 and 10 in the periodogram')
    j = np.argmax(power)

    # Effective number of independent frequencies: 2*nout/ofac.
    expy = np.exp(-power)
    effm = 2.*nout/ofac

    # Small-probability approximation first ...
    fap = expy*effm

    # ... replaced by the full expression wherever it is not small.
    mfap = fap > 0.01
    fap[mfap] = 1.0-(1.0-expy[mfap])**effm

    return period, power, fap, j


def sin_function(theta, x, p):
    """Evaluate a sinusoid of fixed period `p` at `x`.

    Parameters
    ----------
    theta : sequence
        Model parameters, read by index:
        theta[0] is the amplitude, theta[1] the phase offset (radians)
        and theta[2] the zero-point.
    x : float or array-like
        Point(s) at which to evaluate the model.
    p : float
        Period of the sinusoid, in the same units as `x`.

    Returns
    -------
    float or ndarray
        amplitude * sin(2*pi*x/p + phase_offset) + zero_point
    """
    amplitude = theta[0]
    phase_offset = theta[1]
    zero_point = theta[2]

    angle = 2*np.pi*x/p + phase_offset
    return amplitude*np.sin(angle) + zero_point

def lnlike_function(theta, x, y, ye, p):
    """Gaussian log-likelihood of the sinusoid model given the data.

    Parameters
    ----------
    theta : sequence
        Sinusoid parameters (amplitude, phase offset, zero-point), passed
        straight through to `sin_function`.
    x : array-like
        Independent variable of the data.
    y : array-like
        Observed values.
    ye : array-like
        One-sigma uncertainties on `y`.
    p : float
        Fixed period of the sinusoid.

    Returns
    -------
    float
        Log-likelihood, up to an additive constant.
    """
    # No bounds are applied to theta here: every parameter value is accepted.

    # Residuals of the data about the sinusoid model.
    residual = y - sin_function(theta, x, p)

    # Inverse-variance weights.
    inv_var = 1.0/ye**2

    # Chi-square term plus the normalisation term that depends on the errors.
    terms = residual**2*inv_var - np.log(inv_var)
    return -0.5*np.sum(terms)


if __name__=="__main__":
    # Iterative pre-whitening driver.
    #
    # For each of N_prewhiten passes: find the strongest periodogram peak,
    # fit a sinusoid at that period with emcee, plot the phase-folded data
    # with the fit next to the periodogram, subtract the fit from the data
    # and write the residual light curve to 'prewhitened_<pass>.dat'.
    # The summary figure is saved as 'Modulation.png'.

    if len(sys.argv) < 2:
        sys.exit('Usage: {:} <lightcurve file>'.format(sys.argv[0]))

    # Load the data: three columns unpacked as time, magnitude, magnitude error
    time, mag, mag_err = np.loadtxt(sys.argv[1]).T

    # Select the number of times to pre-whiten
    N_prewhiten = 5
    # squeeze=False keeps axs two-dimensional even when N_prewhiten is 1,
    # so the axs[row][col] indexing below always works.
    fig, axs = plt.subplots(nrows=N_prewhiten, ncols=2, figsize=(10,6*N_prewhiten/3), squeeze=False)

    # Phase grid on which the best-fitting model is drawn (shared by all passes)
    phase_grid = np.linspace(0, 1, 1000)

    # Store the rms scatter
    scatters = []
    for i in range(N_prewhiten):
        # The period search is run on the mean-subtracted magnitudes
        periodp, powerp, fapp, jp = psearch(time, mag-np.mean(mag))
        best_period = periodp[jp]

        # First, phase plot the scatter graph
        phase = phaser(time, 0.0, best_period)
        axs[i][0].scatter(phase, 1e3*mag, c='k', s = 3, alpha = 0.5, label=None)

        # Now do the emcee: walkers start in a small Gaussian ball around
        # (amplitude, phase offset, zero-point) = (5e-3, 0, 0)
        ndim = 3
        nwalkers = 4*ndim
        p0 = np.random.normal([5e-3, 0.0, 0.0], 1e-3, size=(nwalkers, ndim))
        sampler = emcee.EnsembleSampler(nwalkers, ndim, lnlike_function, args=[time, mag, mag_err, best_period])
        sampler.run_mcmc(p0, 50000, progress=True)

        # Now extract the best step (highest log-probability sample)
        best_step = sampler.get_chain(flat=True)[np.argmax(sampler.get_log_prob(flat=True))]
        best_model = sin_function(best_step, time, best_period)
        # With t_zero = 0 the phase is time/period modulo 1, so evaluating
        # the model on the phase grid with unit period matches the folded data.
        best_phase_model = sin_function(best_step, phase_grid, 1.0)
        axs[i][0].plot(phase_grid, 1e3*best_phase_model, 'r', label = 'P = {:.2f} d'.format(best_period))
        axs[i][0].legend(loc=2)

        # Now plot the highest peak
        axs[i][1].axvline(best_period, ls='--', zorder=3, c='r', alpha = 0.4)

        # Now plot the periodogram
        axs[i][1].semilogx(periodp, powerp, 'k', linewidth=0.3)
        axs[i][1].xaxis.set_major_formatter(mticker.ScalarFormatter())
        # Mark integer multiples and fractions of the best period
        for j in range(2, 6):
            axs[i][1].axvline(best_period*j, ls='--', zorder=3, c='b', alpha = 0.2)
            axs[i][1].axvline(best_period/j, ls='--', zorder=3, c='b', alpha = 0.2)

        # Now pre-whiten: the next pass works on the residuals
        mag = mag - best_model

        # Record the scatter of the residuals for the axis labels
        scatters.append(1e3*np.std(mag))

        # Now save it as a text file
        np.savetxt('prewhitened_{:}.dat'.format(i), np.array([time, mag, mag_err]).T)

        print('Period {:.0f} is at {:.2f} with amplitude = {:.2f} mmag   (raw RMS = {:.2f} mmag)'.format(i+1, best_period, 1e3*best_step[0], np.std(mag)*1e3))

    # now clean graph: x labels only on the bottom row
    plt.setp(axs[N_prewhiten-1][0], xlabel='Phase')
    plt.setp(axs[N_prewhiten-1][1], xlabel='Period [d]')

    for i in range(N_prewhiten):
        plt.setp(axs[i][0], ylabel='Mag [mmag]\n' + r'$\sigma$ = {:.1f} mmag'.format(scatters[i]))
        plt.setp(axs[i][1], ylabel='Power')
        # Magnitude axis is drawn inverted
        axs[i][0].invert_yaxis()

    # Hide the x tick labels on every row except the last
    for i in range(N_prewhiten-1):
        plt.setp(axs[i][0].get_xticklabels(), visible=False)
        plt.setp(axs[i][1].get_xticklabels(), visible=False)

    plt.tight_layout()

    plt.savefig('Modulation.png')
    plt.show()
