#!python
import matplotlib
matplotlib.use('TkAgg')

import numpy as np
import matplotlib.pyplot as py
import os
import argparse
import warnings
import sys

from FAST import plotparams
from scipy import optimize

def sms(x,params):
    """Evaluate the stop-model curve ``Vmax/(1 + Vmax*x/Ks)``.

    x      -- value(s) at which the curve is evaluated
    params -- indexable whose items 0 and 1 are Vmax and Ks

    Returns the curve value(s) at x.
    """
    Vmax, Ks    = params[0], params[1]
    denominator = 1.0+Vmax*x/Ks

    return Vmax/denominator

def sms_utr_half(params):
    """Return the tuple ``(Ks/Vmax, Vmax)`` for the stop-model parameters.

    params -- indexable whose items 0 and 1 are Vmax and Ks
    """
    Vmax, Ks = params[0], params[1]

    return Ks/Vmax,Vmax

def err_sms(params,x,y):
    """Residuals between the observations y and the stop model evaluated at x.

    The argument order (params first, then the data) is the one expected by
    scipy.optimize.leastsq for its error function.
    """
    model_values = sms(x,params)

    return y - model_values

def read_stats_files(fname):
    """Read a tab-separated statistics file.

    The first line of the file is skipped as a header. For every other line
    the first character is dropped before parsing; a '#' in that position
    marks the row as not valid. The remaining text is split on tabs into
    slide number (int), experiment number (int), file name, protein name and
    any number of float data columns. Blank lines are ignored.

    fname -- path of the file to read

    Returns (valid_lines, stats_out):
      valid_lines -- integer numpy array with the indices into stats_out of
                     the rows that do not start with '#'
      stats_out   -- list with one parsed row (a list) per data line,
                     including the rows flagged with '#'

    Raises ValueError if a non-blank data line cannot be parsed.
    """
    #The context manager closes the file even if reading fails
    with open(fname,'r') as stats_file:
        lines = stats_file.readlines()

    #Parsed rows and the indices of the rows that are valid
    stats_out   = []
    valid_lines = []

    #First line is the header
    for line in lines[1:]:
        #Drop the leading marker character before parsing
        payload = line[1:].strip()

        #Skip blank lines (e.g. an empty line at the end of the file)
        if not payload:
            continue

        entries   = payload.split('\t')
        slide_num = int(entries[0])
        exp_num   = int(entries[1])
        data_file = entries[2].strip()
        protein   = entries[3].strip()

        #Rest of the entries are data
        data      = [slide_num,exp_num,data_file,protein] + [float(x) for x in entries[4:]]
        stats_out.append(data)

        #The data line is valid if the line does not start with #
        #Index by position in stats_out so skipped lines cannot shift it
        if not line.startswith('#'):
            valid_lines.append(len(stats_out)-1)

    #Integer dtype keeps the result usable as an index even when it is empty
    return np.array(valid_lines,dtype=int),stats_out

#Suppress all the warnings
warnings.filterwarnings("ignore")

#Definition of the program - these lines are joined and used as the argparse usage text
usage            = ['%(prog)s -d [DIRECTORY]',
                    '------------------------------------------------------------------------',
                    'LIMA v1.0: Loaded In vitro Motility Assay',
                    '05/15/2015',
                    'Tural Aksel',
                    '',
                    'LIMA extracts the loaded motility result from FAST outputs',
                    'For bugs and other problems please contact Tural Aksel at turalaksel@gmail.com',
                    '------------------------------------------------------------------------'
                    ]

#Create the parser
parser = argparse.ArgumentParser(description='',usage='\n'.join(usage))
parser.add_argument('-d' ,default = None , help='top directory of the output files to be analyzed')
parser.add_argument('-amin' ,default = 0.0  , type=float              , help='minimum load concentration for analysis')
parser.add_argument('-amax' ,default = 0.0  , type=float              , help='maximum load concentration for analysis')
parser.add_argument('-pmin' ,default = 0.0  , type=float              , help='minimum load concentration for plotting')
parser.add_argument('-pmax' ,default = 0.0  , type=float              , help='maximum load concentration for plotting')
parser.add_argument('-p'  ,nargs   = '*'  , default = None, type=str  , help='protein names to be analyzed')
#-r is a store_false switch: parameters are printed unless the flag is given
parser.add_argument('-r'  ,default = True, action = 'store_false'     , help='do not print fit parameters on the plot (they are printed by default)')
parser.add_argument('-g'  ,default = None , nargs = '*', type=float   , help='initial guess for the parameters to be fitted')
parser.add_argument('-cl' ,nargs   = '*'  , default = None, type=str  , help='plotting colors')

#Parse the command line; the help text is printed on every run
args = parser.parse_args()
args_dict = vars(args)
parser.print_help()

#Get main directory and the rest of the options
main_dir           = args_dict['d']
max_load_analysis  = args_dict['amax']
min_load_analysis  = args_dict['amin']
max_load_plot      = args_dict['pmax']
min_load_plot      = args_dict['pmin']
protein_names      = args_dict['p']
print_params       = args_dict['r']
init_params        = args_dict['g']
colors             = args_dict['cl']

#Plotting colors - use the defaults when -cl is absent or given without values
if not colors:
    colors        = ['b','r','g','c','m','y']

#Fit function parameters
mobile_init_params = [100, 0.2]
fit_function      = sms
fit_err_function  = err_sms
fit_utr_half      = sms_utr_half
fit_function_name = '_stop_model'

#By default the load molecule is utrophin
load_molecule  = r'utrophin'

#Check if the last character is '/' - if yes, remove it
if main_dir is not None and len(main_dir) > 0 and main_dir[-1] == '/':
    main_dir = main_dir[:-1]

#Check if the directory exists
if main_dir is None or not os.path.isdir(main_dir):
    sys.exit("Directory or file doesn't exist. Program is exiting.")

##Combined-averaged data
all_mean_file      = main_dir+'/combined/MEAN_values.txt'
all_std_file       = main_dir+'/combined/SEM_values.txt'

#Both files are read below, so both have to exist
for stats_file in (all_mean_file,all_std_file):
    if not os.path.isfile(stats_file):
        sys.exit("Combined analysis file does not exist. Program is exiting.")

#Read the SEM and mean data; the valid row indices of the mean file are used for both
valid,std_data  = read_stats_files(all_std_file)
valid,mean_data = read_stats_files(all_mean_file)

#Nothing can be analyzed without at least one valid row
if len(valid) == 0:
    sys.exit("No valid data lines in the combined analysis file. Program is exiting.")

#Header name for the files generated
header = main_dir+'/combined/lima'
if not os.path.exists(header):
    os.mkdir(header)

#Get only valid data
header_data    = [mean_data[i][:9] for i in valid]
slide_list     = np.array([int(x[0]) for x in header_data])
expnum_list    = np.array([int(x[1]) for x in header_data])
fname_list     = np.array([str(x[2]) for x in header_data])
protein_list   = np.array([x[3] for x in header_data])
utr_list       = np.array([float(x[6]) for x in header_data])

#Averaged data
mean_data     = np.array([mean_data[i][7:] for i in valid])
std_data      = np.array([std_data[i][7:]  for i in valid])

#Get protein names
proteins      = sorted(set(protein_list))

#Sorted set of the utrophin concentrations present in the data
utr_set       = sorted(set(utr_list))

#If maximum utrophin for plot/analysis is not entered, retrieve from the data
if max_load_analysis == 0.0:
    max_load_analysis = max(utr_set)

if max_load_plot == 0.0:
    max_load_plot = max_load_analysis

#Pick only proteins you want to analyze (-p without names means all proteins)
if protein_names:
    proteins = protein_names
else:
    protein_names = proteins

#Cap-size
cap_size = 25

#Parameters tail
params_tail = ''
if print_params:
    params_tail = '_with_parameters'

#Tail for the plots generated
tail = params_tail
#Plot dimensions
x_plot,y_plot = plotparams.get_figsize(1200)

#All plots in one canvas
py.figure(0,figsize=(2*x_plot,1*y_plot))

#Percent mobile - avg
lines        = []
ref_velocity = 0

for i in range(len(proteins)):
    py.subplot(121)
    valid_pro          = np.nonzero(protein_list == proteins[i])[0]
    
    mean_data_filtered = mean_data[valid_pro,:]
    std_data_filtered  = std_data[valid_pro,:]
    
    #Utrophin concentration
    utr_conc           =  utr_list[valid_pro]
    
    maxvelocity        =  mean_data_filtered[:,0]
    std_maxvelocity    =  std_data_filtered[:,0]
    
    mean_percent_stuck =  mean_data_filtered[:,1]
    std_percent_stuck  =  std_data_filtered[:,1]
    
    mean_MVEL          =  mean_data_filtered[:,2]
    std_MVEL           =  std_data_filtered[:,2]
    
    mean_MVIS          =  mean_data_filtered[:,5]
    std_MVIS           =  std_data_filtered[:,5]
    
    #Slide and experiment numbers
    slide_nums         = slide_list[valid_pro]
    exp_nums           = expnum_list[valid_pro]
    
    #Combine data for in-place sorting
    full_data = np.array([[utr_conc[j],mean_percent_stuck[j],std_percent_stuck[j],mean_MVEL[j],std_MVEL[j],mean_MVIS[j],std_MVIS[j],maxvelocity[j],std_maxvelocity[j],slide_nums[j],exp_nums[j]] for j in range(len(utr_conc))],dtype='f8')
    full_data = full_data.view('f8,f8,f8,f8,f8,f8,f8,f8,f8,f8,f8')
    full_data.sort(order='f0',axis=0)
    
    utr                 = full_data['f0'].flatten()
    mean_frac_mobile    = 1.0 - full_data['f1'].flatten()/100.0
    std_frac_mobile     = full_data['f2'].flatten()/100.0
    
    mean_MVEL           = full_data['f3'].flatten()
    std_MVEL            = full_data['f4'].flatten()
    
    mean_MVIS           = full_data['f5'].flatten()
    std_MVIS            = full_data['f6'].flatten()
    
    maxvelocity         = full_data['f7'].flatten()
    std_maxvelocity     = full_data['f8'].flatten()
    
    slide_nums          = full_data['f9'].flatten()
    exp_nums            = full_data['f10'].flatten()
    
    #Reference velocity is the maximum velocity at 0 nM utrophin
    ref_velocity          = maxvelocity[0]
    
    mobile_correction     = mean_MVEL/ref_velocity
    mean_frac_time_mobile = mean_frac_mobile*mobile_correction
    std_frac_time_mobile  = std_frac_mobile*mobile_correction
    
    #Analyze only the data less than a utrophin concentration
    valid_utr             = np.nonzero((utr <= max_load_analysis)*(utr >= min_load_analysis))[0]
    utr                   = utr[valid_utr]
    
    maxvelocity           = maxvelocity[valid_utr]
    std_maxvelocity       = std_maxvelocity[valid_utr]
    
    mean_frac_mobile      = mean_frac_mobile[valid_utr]
    std_frac_mobile       = std_frac_mobile[valid_utr]
    
    mean_MVEL             = mean_MVEL[valid_utr]
    std_MVEL              = std_MVEL[valid_utr]
    
    mean_MVIS             = mean_MVIS[valid_utr]
    std_MVIS              = std_MVIS[valid_utr]
    
    mean_frac_time_mobile = mean_frac_time_mobile[valid_utr]
    std_frac_time_mobile  = std_frac_time_mobile[valid_utr]
    
    slide_nums            = slide_nums[valid_utr]
    exp_nums              = exp_nums[valid_utr]
    
    num_points            = len(utr)
    
    if num_points == 0:
        sys.exit("No data points for plotting!.Exiting.")
    
    #Plot percent time mobile
    py.subplot(121)
    py.errorbar(utr,mean_frac_time_mobile*100,yerr=std_frac_time_mobile*100,marker='o',color=colors[i],linestyle='None',capsize=cap_size)
    
    #Plot MVIS
    py.subplot(122)
    py.errorbar(utr,mean_MVIS,yerr=std_MVIS,marker='o',color=colors[i],linestyle='None',capsize=cap_size)
    
    data_combined = np.hstack((np.vstack(utr),np.vstack(mean_frac_mobile)*100.0,np.vstack(std_frac_mobile)*100.0))
    np.savetxt(header+'/'+proteins[i]+'_percent_mobile.txt',data_combined)
    
    data_combined = np.hstack((np.vstack(utr),np.vstack(mean_frac_time_mobile)*100.0,np.vstack(std_frac_time_mobile)*100.0))
    np.savetxt(header+'/'+proteins[i]+'_percent_time_mobile.txt',data_combined)
    
    data_combined = np.hstack((np.vstack(utr),np.vstack(mean_MVIS),np.vstack(std_MVIS)))
    np.savetxt(header+'/'+proteins[i]+'_MVIS.txt',data_combined)
    
    data_combined = np.hstack((np.vstack(utr),np.vstack(mean_MVIS),np.vstack(std_MVIS)))
    np.savetxt(header+'/'+proteins[i]+'_MVEL.txt',data_combined)
    
    #There should be at least 3 data points for fitting
    if num_points < 3:
        continue
    
    #Back to percent time mobile panel
    py.subplot(121)
    
    #Fit model to data
    sim_utr= np.linspace(min_load_analysis,max_load_analysis,1000)
    
    params = mobile_init_params
    best_params,success = optimize.leastsq(fit_err_function,params,args=(utr,mean_frac_time_mobile),maxfev=1000)
    
    #The parameters relevant to force production
    V0  = best_params[0]
    Ks  = best_params[1]
    
    #value at Ksinv 
    X_val = fit_function(Ks,best_params)
    
    #Plot line
    line       = py.plot(sim_utr,fit_function(sim_utr,best_params)*100,color=colors[i],linestyle='-')
    
    if print_params:
        py.text(Ks,50-i*10,r'$%.2f nM^{K_S}$'%(Ks),color=colors[i],fontsize=50)
    lines.append(line[0])
    
    #Residuals
    residuals  = fit_err_function(best_params,utr,mean_frac_time_mobile)
    
    #Plot half-load/half-value
    py.plot([Ks,Ks],[X_val*100,0.0],color=colors[i],linestyle='--',linewidth=10)
    
    data_combined = np.hstack((np.vstack(utr),np.vstack(residuals)*100.0))
    np.savetxt(header+'/'+proteins[i]+'_percent_time_mobile_fit_residuals.txt',data_combined)
    
    data_combined = np.hstack((np.vstack(sim_utr),np.vstack(fit_function(sim_utr,best_params))*100))
    np.savetxt(header+'/'+proteins[i]+'_percent_time_mobile_simulated.txt',data_combined)
    
    #Plot the velocity panel
    py.subplot(122)
    
    #Plot line
    py.plot(sim_utr,fit_function(sim_utr,best_params)*ref_velocity,color=colors[i],linestyle='-')
    
    #Plot half-load/half-value
    py.plot([Ks,Ks],[X_val*ref_velocity,0.0],color=colors[i],linestyle='--',linewidth=10)
    
    data_combined = np.hstack((np.vstack(sim_utr),np.vstack(fit_function(sim_utr,best_params))*ref_velocity))
    np.savetxt(header+'/'+proteins[i]+'_MVIS_simulated.txt',data_combined)

py.subplot(121)
py.legend(lines,proteins,loc=1)
py.xlim([min_load_plot,max_load_plot])
py.ylim([0,100])
py.ylabel('% Time mobile')
py.xlabel('Utrophin (nM)')

py.subplot(122)
py.legend(lines,proteins,loc=1)
py.xlim([min_load_plot,max_load_plot])
py.ylabel('MVIS (nm/s)')
py.xlabel('Utrophin (nM)')

py.savefig(header+'/'+'_vs_'.join(proteins)+fit_function_name+tail+'.png',dpi=100)
py.close()
    
#Filament length - avg
py.figure(1,figsize=(x_plot,y_plot))
lines = []

for i,protein in enumerate(proteins):
    #Cycle through the colors if there are more proteins than colors
    color     = colors[i % len(colors)]
    valid_pro = np.nonzero(protein_list == protein)[0]

    #Combine data for in-place sorting by utrophin concentration.
    #Column order: utrophin, length (mean,std), slide number, experiment number.
    #The filament length is the third column from the end of the averaged data.
    #All five columns are kept because the structured sort uses the remaining
    #fields to break ties in the utrophin concentration.
    columns   = np.column_stack((utr_list[valid_pro],
                                 mean_data[valid_pro,-3],std_data[valid_pro,-3],
                                 slide_list[valid_pro],expnum_list[valid_pro]))
    full_data = np.array(columns,dtype='f8').view('f8,f8,f8,f8,f8')
    full_data.sort(order='f0',axis=0)

    utr             = full_data['f0'].flatten()
    mean_length     = full_data['f1'].flatten()
    std_length      = full_data['f2'].flatten()

    #Analyze only the data inside the requested utrophin concentration range
    in_range        = (utr <= max_load_analysis) & (utr >= min_load_analysis)
    utr             = utr[in_range]
    mean_length     = mean_length[in_range]
    std_length      = std_length[in_range]

    line = py.errorbar(utr, mean_length,yerr=std_length,marker='o',color=color,linestyle='-',capsize=cap_size)
    lines.append(line[0])

    #Three-column text file: utrophin, mean length, error
    np.savetxt(header+'/'+protein+'_length.txt',np.column_stack((utr,mean_length,std_length)))

#Plotting parameters - the x range follows the plotting limits, like the first figure
py.xlim([min_load_plot,max_load_plot])
py.xlabel(load_molecule+'(nM)')
py.ylabel('Filament length(nm)')
py.legend(lines,proteins,loc=1)

py.savefig(header+'/'+'_vs_'.join(proteins)+'_length'+tail+'.png',dpi=200)
py.close()