#!/usr/bin/env python

from pyPheWAS.pyPhewasCorev2 import *
import os
import time
import math
import argparse
import sys

def parse_args():
    """Parse the pipeline's command-line options from ``sys.argv``.

    Three options are mandatory (``--phenotype``, ``--group``, ``--reg_type``);
    the remaining ones are optional and fall back to the defaults listed in
    the table below.

    Returns:
        argparse.Namespace: one attribute per option, named after the flag
        without its leading dashes (e.g. ``args.phenotype``).
    """
    cli = argparse.ArgumentParser(description="pyPheWAS ICD-Phecode Lookup Tool")

    # Mandatory string options: (flag, help text).
    mandatory = (
        ('--phenotype', 'Name of the phenotype file (e.g. icd9_data.csv)'),
        ('--group', 'Name of the group file (e.g. groups.csv)'),
        ('--reg_type', 'Type of regression that you would like to use (log, lin, or dur)'),
    )
    for flag, help_text in mandatory:
        cli.add_argument(flag, required=True, type=str, help=help_text)

    # Optional options: (flag, default, converter, help text).
    optional = (
        ('--path', '.', str, 'Path to all input files and destination of output files'),
        ('--postfix', None, str, 'Descriptive postfix for output files (e.g. poster or ages50-60)'),
        ('--phewas_cov', None, float, 'PheCodes to use as covariates in pyPhewasModel regression'),
        ('--covariates', '', str, 'Variables to be used as covariates'),
        ('--response', '', str, 'Variable to predict instead of genotype'),
        ('--imbalance', "True", str, 'Whether or not to show the direction of imbalance in the plot'),
        ('--phewas_label', "plot", str, 'Where to put PheCode labels - plot (default) or axis'),
        ('--thresh_type', None, str, 'Type of threshold to be used in the plot (fdr, bon, or custom)'),
        ('--custom_thresh', None, float, 'Custom threshold value (float between 0 and 1)'),
        ('--plot_format', "svgz", str, 'File extension for plot files (default: svgz)'),
    )
    for flag, fallback, converter, help_text in optional:
        cli.add_argument(flag, required=False, default=fallback, type=converter, help=help_text)

    return cli.parse_args()

"""
Retrieve and validate all arguments.
"""
# Wall-clock start; read again at the end of the script to report the runtime.
start = time.time()

args = parse_args()
kwargs = {
    # Joining with '' leaves a trailing separator, so later code can build
    # file names with plain ``path + filename``.
    'path': os.path.join(os.path.abspath(args.path), ''),
    'phenotypefile': args.phenotype,
    'groupfile': args.group,
    'phewas_cov': args.phewas_cov,
    'postfix': args.postfix,
    # NOTE(review): eval() executes whatever expression is given on the
    # command line. Left unchanged so every previously accepted value keeps
    # working; consider an explicit "True"/"False" parse instead.
    'show_imbalance': eval(args.imbalance),
    'custom_thresh': args.custom_thresh,
    'covariates': args.covariates,
    'response': args.response,
    'phewas_label': args.phewas_label,
    'plot_format': args.plot_format,
}

str_reg_type = args.reg_type
str_thresh_type = args.thresh_type

# Input checks raise explicitly instead of using ``assert``: assert statements
# are removed under ``python -O``, which would silently skip the validation.

# Check that a valid regression type was used
if str_reg_type not in regression_map:
    raise ValueError("%s is not a valid regression type" % str_reg_type)
kwargs['reg_type'] = regression_map[str_reg_type]

# Check that valid files were given
if not kwargs['phenotypefile'].endswith('.csv'):
    raise ValueError("%s is not a valid phenotype file, must be a .csv file" % (kwargs['phenotypefile']))
if not kwargs['groupfile'].endswith('.csv'):
    raise ValueError("%s is not a valid group file, must be a .csv file" % (kwargs['groupfile']))

# Build the output postfix: the user-supplied postfix (or, failing that, the
# group file name) with its extension removed, prefixed by the covariates
# string when one was given.
if kwargs['postfix'] is None:
    postfix_source = kwargs['groupfile']
else:
    postfix_source = kwargs['postfix']
postfix_stem = os.path.splitext(postfix_source)[0]
# Compare by value; ``is not ''`` tests object identity and is not reliable.
if kwargs['covariates'] != '':
    kwargs['postfix'] = kwargs['covariates'] + '_' + postfix_stem
else:
    kwargs['postfix'] = postfix_stem

# Check phewas_cov (argparse already converts it; the cast is kept as a safeguard)
if kwargs['phewas_cov']:
    kwargs['phewas_cov'] = float(kwargs['phewas_cov'])

if kwargs['response'] is None:
    kwargs['response'] = ""

# Check that a valid threshold type was used
if args.thresh_type is None:
    # No explicit choice: produce plots for both default threshold types.
    kwargs['thresh_type'] = ['fdr', 'bon']
else:
    if str_thresh_type not in threshold_map:
        raise ValueError("%s is not a valid threshold type" % (str_thresh_type))
    # Threshold code 2 is the one the plotting stage resolves to
    # kwargs['custom_thresh']; reject a missing or non-positive value now
    # rather than failing after all regressions have been run.
    if threshold_map[str_thresh_type] == 2:
        if kwargs['custom_thresh'] is None:
            raise ValueError("--custom_thresh must be provided when --thresh_type is %s" % (str_thresh_type))
        if not kwargs['custom_thresh'] > 0:
            raise ValueError("--custom_thresh must be greater than 0, got %s" % (kwargs['custom_thresh']))
    kwargs['thresh_type'] = [str_thresh_type]

if kwargs['phewas_label'] not in ["plot", "axis"]:
    raise ValueError("%s is not a valid PheCode label location" % (kwargs['phewas_label']))

# Print Arguments
display_kwargs(kwargs)

# Expose every argument as a module-level variable (path, postfix, reg_type,
# ...). At module scope locals() and globals() are the same dictionary, so
# this matches the previous ``locals().update`` while being explicit about it.
globals().update(kwargs)


""" 
pyPhewasLookup 
"""

# Stage 1: load the phenotype and group inputs, then build and save the
# feature matrices.
print("Retrieving phenotype data...")
phenotypes = get_icd_codes(path, phenotypefile, reg_type)

print("Retrieving group data...")
genotypes = get_group_file(path, groupfile)

print("Generating feature matrix...")
fm, columns = generate_feature_matrix(genotypes, phenotypes, reg_type, phewas_cov)

fm_outfile = ''.join(["feature_matrix_", postfix, ".csv"])
print("Saving feature matrices to %s" % (path + fm_outfile))
fm_header = ','.join(columns)

# One CSV per entry of fm, each written under its own file-name prefix and
# sharing the same column header; a progress marker is printed between writes.
fm_prefixes = ('agg_measures_', 'icd_age_', 'phewas_cov_')
for fm_index, fm_prefix in enumerate(fm_prefixes):
    if fm_index > 0:
        print("...")
    np.savetxt(path + fm_prefix + fm_outfile, fm[fm_index], delimiter=',', header=fm_header)


""" 
pyPhewasModel 
"""

# Stage 2: run the regressions and save the results table.
print("Running PheWAS regressions...")
regressions = run_phewas(fm, genotypes, covariates, reg_type, response, phewas_cov)

reg_outfile = "regressions_" + postfix + ".csv"
print("Saving regression data to %s" % (path + reg_outfile))
# First line of the file records which regression type and group file
# produced the results; the regression table follows it.
header = ','.join(['str_reg_type', str_reg_type, 'group', groupfile]) + '\n'
# ``path`` already ends with a separator, so concatenate directly (the same
# location that is printed above) instead of inserting a second separator.
# The context manager closes the file even if writing fails.
with open(path + reg_outfile, 'w') as f:
    f.write(header)
    regressions.to_csv(f)


""" 
pyPhewasPlot 
"""


# Stage 3: split the confidence-interval column into numeric bounds, then
# draw one Manhattan plot and one odds-ratio plot per requested threshold.
try:
    # 'Conf-interval beta' is split on ',' into a lower and an upper part;
    # the surrounding brackets are stripped before converting to float.
    regressions[['lowlim', 'uplim']] = regressions['Conf-interval beta'].str.split(',', expand=True)
    # regex=False: '[' and ']' are regex metacharacters, so request a literal
    # replacement instead of relying on the pandas-version-dependent default.
    regressions['uplim'] = regressions.uplim.str.replace(']', '', regex=False)
    regressions['lowlim'] = regressions.lowlim.str.replace('[', '', regex=False)
    regressions = regressions.astype(dtype={'uplim': float, 'lowlim': float})
except Exception as e:
    print('Error reading regression file:')
    print(e)
    # Exit with a non-zero status so callers can detect the failure.
    sys.exit(1)

pvalues = regressions['p-val'].values

# Remove leading periods if they were given by the user (e.g. ".png"). The
# result has to be kept: str methods return a new string rather than
# modifying the original in place.
plot_ext = plot_format.lstrip('.')

for t in thresh_type:
    t_num = threshold_map[t]

    # Resolve the numeric threshold for this threshold type
    if t_num == 0:
        thresh = get_bon_thresh(pvalues, 0.05)
    elif t_num == 1:
        thresh = get_fdr_thresh(pvalues, 0.05)
    elif t_num == 2:
        thresh = kwargs['custom_thresh']
        if thresh is None:
            raise ValueError("custom_thresh must be provided when thresh_type is %s" % t)
    else:
        # Without this branch an unknown code would leave ``thresh`` unbound
        # (or stale from the previous iteration).
        raise ValueError("Unrecognized threshold code %s for threshold type %s" % (t_num, t))
    print('%s threshold: %0.5f' % (t, thresh))

    save = path + t + '_' + postfix + '.' + plot_ext
    saveb = path + t + '_' + postfix + '_beta.' + plot_ext
    print('Saving plots to %s' % save)

    # Both plots take the threshold on the -log10 scale.
    neg_log_thresh = -math.log10(thresh)
    plot_manhattan(regressions, neg_log_thresh, show_imbalance, save, plot_ext)
    plot_odds_ratio(regressions, neg_log_thresh, show_imbalance, saveb, plot_ext, phewas_label)

# Report the total wall-clock runtime, broken into hours, minutes and seconds.
elapsed = time.time() - start
hour = math.floor(elapsed/3600.0)
leftover = elapsed - hour*3600
minute = math.floor(leftover/60)
second = math.floor(leftover - minute*60)

# Only show the units that are needed: leading zero-valued units are dropped.
if hour > 0:
    runtime_fmt, runtime_parts = '%dh:%dm:%ds', (hour, minute, second)
elif minute > 0:
    runtime_fmt, runtime_parts = '%dm:%ds', (minute, second)
else:
    runtime_fmt, runtime_parts = '%ds', (second,)
runtime_label = runtime_fmt % runtime_parts

print('pyPhewasPipeline Complete\nRuntime: %s' % runtime_label)
