#!python


import numpy as np
import matplotlib.pyplot as plt
import numba  
from scipy.optimize import minimize
import argparse 

# Welcome banner printed when the script starts.
# The contact address previously read "wariwck.ac.uk" (transposed letters);
# the corrected spelling has the same length, so the row width is unchanged.
welcome_message = '''---------------------------------------------------
-                   massfunc                   -
-             samuel.gill@warwick.ac.uk           -
---------------------------------------------------'''

# Command-line interface. Every option takes a float and has a default, so
# the script can be run without any arguments.
description = '''Solve the mass function'''
parser = argparse.ArgumentParser(prog='massfunc', description=description)

# Primary mass
parser.add_argument(
    '-m', '--M1',
    type=float,
    default=1.0,
    help='The mass of the primary star in solar units.',
)

# Orbital inclination
parser.add_argument(
    '-i', '--incl',
    type=float,
    default=90,
    help='The inclination of the orbital axis (in deg).',
)

# Orbital eccentricity
parser.add_argument(
    '-e', '--ecc',
    type=float,
    default=0.,
    help='The orbital eccentricity.',
)

# Orbital period
parser.add_argument(
    '-P', '--period',
    type=float,
    default=1.,
    help='The orbital period in days.',
)

# Radial-velocity semi-amplitude
parser.add_argument(
    '-K', '--K1',
    type=float,
    default=10.,
    help='The semi-amplitude in km/s.',
)




# Constants shared by the mass-function routines (SI units).
G = 6.67408e-11  # gravitational constant [m3 kg-1 s-2]
Msun = 1.989e30  # solar mass [kg]; converts masses given in solar units
PI = np.pi  # full double precision; the old literal was cut at 12 digits


###################################################
# Fortran conversions
###################################################
@numba.njit
def sign(a, b):
    """Return the magnitude of ``a`` with the sign of ``b`` (Fortran ``SIGN``).

    The result is ``abs(a)`` when ``b >= 0.0`` and ``-abs(a)`` in every
    other case.
    """
    magnitude = abs(a)
    return magnitude if b >= 0.0 else -magnitude

    

###################################################
# Brent root finding
###################################################

@numba.njit
def brent(func, x1, x2, z0):
    """Find a root of ``func(x, z0)`` between ``x1`` and ``x2`` (Brent's method).

    Each step tries an interpolated update (secant when only two distinct
    points are available, inverse quadratic otherwise) and falls back to
    bisection whenever the interpolated step is rejected.

    Parameters
    ----------
    func : callable
        Function called as ``func(x, z0)`` whose zero is sought.
    x1, x2 : float
        End points of the search interval; ``func`` must change sign
        between them.
    z0 : object
        Extra argument forwarded unchanged to ``func``.

    Returns
    -------
    float
        The root estimate. NaN is returned when ``func(x1, z0)`` and
        ``func(x2, z0)`` have the same sign (no bracketed root) or when the
        iteration budget is exhausted. Previously the first case was not
        detected and the second returned ``1``, which is indistinguishable
        from a genuine answer.
    """
    # pars
    tol = 1e-5     # absolute tolerance on the root position
    itmax = 100    # maximum number of iterations
    eps = 1e-5     # relative tolerance on the root position

    a = x1
    b = x2
    c = 0.
    d = 0.
    e = 0.
    fa = func(a, z0)
    fb = func(b, z0)

    # The scheme below only works when the end points straddle a sign change.
    if (fa > 0.0 and fb > 0.0) or (fa < 0.0 and fb < 0.0):
        return np.nan

    fc = fb

    for _ in range(itmax):
        # Keep the root bracketed between b and c: if fb and fc share a sign,
        # move c to the previous iterate a.
        if (fb*fc > 0.0):
            c = a
            fc = fa
            d = b-a
            e=d   

        # Make b the point with the smaller residual (the current best guess).
        if (abs(fc) < abs(fb)):
            a = b
            b = c
            c = a
            fa = fb
            fb = fc
            fc = fa

        # Converged when the half-width of the bracket is below the tolerance
        # or the residual is exactly zero.
        tol1 = 2.0*eps*abs(b)+0.5*tol
        xm = (c-b)/2.0
        if (abs(xm) <  tol1 or fb == 0.0) : return b

        if (abs(e) > tol1 and abs(fa) >  abs(fb)):
            # Attempt an interpolated step.
            s = fb/fa
            if (a == c):
                # Only two distinct points: secant (linear) interpolation.
                p = 2.0*xm*s
                q = 1.0-s
            else:
                # Three distinct points: inverse quadratic interpolation.
                q = fa/fc
                r = fb/fc
                p = s*(2.0*xm*q*(q-r)-(b-a)*(r-1.0))
                q = (q-1.0)*(r-1.0)*(s-1.0)
            
            if (p > 0.0) : q = - q
            p = abs(p)
            # Accept the interpolated step only if it stays inside the bracket
            # and shrinks fast enough; otherwise bisect.
            if (2.0*p < min(3.0*xm*q-abs(tol1*q),abs(e*q))):
                e = d
                d = p/q
            else:
                d = xm
                e = d
        else:
            # Progress is too slow: bisect.
            d = xm
            e = d   

        a = b
        fa = fb      
         
        # Always move by at least tol1, in the direction of the bracket.
        if( abs(d) > tol1) : b = b + d
        else : b = b + sign(tol1, xm)

        fb = func(b,z0)

    # Iteration budget exhausted without meeting the tolerance.
    return np.nan

@numba.njit
def mass_function_1(e, P, K):
    """Evaluate the binary mass function from the orbital observables.

    Computes ``f(m) = P * K**3 * (1 - e**2)**1.5 / (2 * PI * G)``. The
    eccentricity factor is raised to the power 3/2; it was previously
    applied to the first power, which is only right for circular orbits.

    Parameters
    ----------
    e : float
        Orbital eccentricity.
    P : float
        Orbital period in days.
    K : float
        Radial-velocity semi-amplitude in km/s.

    Returns
    -------
    float
        The mass function in kg (``G`` is in SI units).
    """
    K = K*1e3  # convert km/s to m/s
    P = P*86400  # convert days to seconds
    return (1 - e**2)**1.5 * P * K**3 / (2*PI*G)


@numba.njit
def mass_function_2(M1, M2, i):
    """Evaluate the binary mass function from the two masses and inclination.

    Computes ``(M2 * sin(i))**3 / (M1 + M2)**2`` after scaling both masses
    by ``Msun``.

    Parameters
    ----------
    M1, M2 : float
        Primary and companion masses in solar units.
    i : float
        Inclination passed directly to ``np.sin`` (i.e. in radians).

    Returns
    -------
    float
        The mass function in the units of ``Msun``.
    """
    primary = M1*Msun
    companion = M2*Msun
    projected_cubed = (companion*np.sin(i))**3
    total_squared = (primary + companion)**2
    return projected_cubed / total_squared

@numba.njit
def min_func(M2, z0):
    """Return the mismatch between the two forms of the mass function.

    Parameters
    ----------
    M2 : float
        Trial companion mass, forwarded to ``mass_function_2``.
    z0 : sequence of five floats
        ``(M1, i, e, P, K)`` in exactly that order.

    Returns
    -------
    float
        ``mass_function_1(e, P, K) - mass_function_2(M1, M2, i)``; zero when
        the trial mass reproduces the observed mass function.
    """
    primary_mass, inclination, eccentricity, period, semi_amplitude = z0
    observed = mass_function_1(eccentricity, period, semi_amplitude)
    predicted = mass_function_2(primary_mass, M2, inclination)
    return observed - predicted



if __name__ == '__main__':
    # Script entry point: solve the mass function for the companion mass and
    # plot the residual between its two forms over a grid of trial masses.

    # parse args
    args = parser.parse_args()

    # Print the welcome message 
    print(welcome_message)

    # Pack the arguments in the order min_func unpacks them:
    # (M1, inclination converted from degrees to radians, e, P, K1).
    z0 = np.array([args.M1, args.incl*np.pi/180., args.ecc, args.period, args.K1])

    # Absolute residual on a grid of trial companion masses.
    m2_grid = np.linspace(0.000001, 1, 10000)
    residual = np.abs(np.array([min_func(m2, z0) for m2 in m2_grid]))

    plt.figure(figsize=(15,5))
    plt.semilogy(m2_grid, residual, c='k')

    # Companion mass at which the residual crosses zero.
    m2_root = brent(min_func, 0.00001, 0.7, z0)
    # The value is a mass, so report it with a mass unit (was "R_sol").
    print('Mass of companion : {:.5f} M_sol'.format(m2_root))

    plt.axvline(m2_root, c='k', ls='--')
    plt.xlabel('M2 [M_sun]')
    # This was a second xlabel() call, which overwrote the x-axis label and
    # left the y axis unlabelled.
    plt.ylabel('f(m) residual')
    print('---------------------------------------------------')
    plt.show()