#!/usr/bin/env python

from pyPheWAS.pyPhewasCore import *
import sys, os

# Command-line driver: reads a regression statistics CSV, computes a
# significance threshold, and plots -log10(p) for every regression.
#
# Flags (each takes a value):
#   --statfile     name of the .csv statistics file, relative to --path
#   --imbalance    Python literal (e.g. True / False) toggling imbalances
#   --thresh_type  key of threshold_map selecting the threshold method
#   --path         working directory (default '.')
#   --outfile      optional .png/.pdf file to save to; when omitted,
#                  an empty save name is passed to the plotting routine

# Map each command-line flag to the keyword-argument name it populates
optargs = {
    '--statfile': 'statfile',
    '--imbalance': 'show_imbalance',
    '--path': 'path',
    '--thresh_type': 'thresh_type',
    '--outfile': 'outfile'
}

args = sys.argv[1:]

# Define default arguments
kwargs = {'outfile': None, 'path': '.'}

kwargs = process_args(kwargs, optargs, *args)

# These arguments have no default; fail with a readable message rather than
# a bare KeyError further down when one of them was not supplied.
missing = [flag for flag, key in optargs.items() if key not in kwargs]
assert not missing, "Missing required argument(s): %s" % (', '.join(sorted(missing)))

# Change path to absolute path (the trailing '' keeps a trailing separator,
# so file names can simply be appended to it below)
kwargs['path'] = os.path.join(os.path.abspath(kwargs['path']), '')

# Assert that a valid threshold type was used
assert kwargs['thresh_type'] in threshold_map, "%s is not a valid threshold type" % (kwargs['thresh_type'])
str_thresh_type = kwargs['thresh_type']
kwargs['thresh_type'] = threshold_map[str_thresh_type]

# Evaluate the imbalance string
# SECURITY NOTE(review): eval() executes arbitrary code taken from the command
# line. It is kept so every previously accepted value still works; consider
# ast.literal_eval or an explicit True/False parser.
kwargs['show_imbalance'] = eval(kwargs['show_imbalance'])

# Assert that valid files were given
assert kwargs['statfile'].endswith('.csv'), "%s is not a valid phenotype file, must be a .csv file" % (kwargs['statfile'])

# Specify no output if output file was not given
if kwargs['outfile'] is None:
    kwargs['outfile'] = ''
else:
    assert kwargs['outfile'].endswith(('.png', '.pdf')), "%s is not a valid plot file, must be a .png or a .pdf" % (kwargs['outfile'])

# The first line of the statistics file is a flat "key,value,key,value,..."
# header; everything after it is the CSV table itself.
with open(kwargs['path'] + kwargs['statfile']) as ff:
    header = ff.readline().strip().split(',')
    # Pair even-indexed keys with odd-indexed values; a trailing key without
    # a value is ignored instead of raising IndexError.
    for key, value in zip(header[::2], header[1::2]):
        kwargs[key] = value

    # Print Arguments
    display_kwargs(kwargs)

    # Read in the remaining data (the pandas DataFrame) from the same handle,
    # which is now positioned just past the header line
    regressions = pd.read_csv(ff)

# Bind the arguments used below explicitly. They are read after the header
# merge, so a header entry with the same key still takes precedence.
show_imbalance = kwargs['show_imbalance']
thresh_type = kwargs['thresh_type']
outfile = kwargs['outfile']
path = kwargs['path']

# NOTE: the column name itself contains the double quotes
y = regressions['"-log(p)"']

# Check if an imbalance will be used
if show_imbalance:
    imbalances = get_imbalances(regressions)
else:
    imbalances = np.array([])

# Get the regular p-values: p = 10 ** (-(-log10(p)))
regpvalues = 10.0 ** -np.asarray(y, dtype=float)

# Get the threshold type
if thresh_type == 0:
    thresh = get_bon_thresh(y, 0.05)
elif thresh_type == 1:
    thresh = get_fdr_thresh(regpvalues, 0.05)
else:
    # Without this branch `thresh` would be unbound and the plot call below
    # would fail with a NameError.
    raise ValueError("No threshold computation is implemented for threshold type %s" % (str_thresh_type))

if outfile != '':
    save = path + outfile
    print("Saving plot to %s" % (save))
else:
    save = ''
    print("Displaying plot.")

plot_data_points(y, -math.log10(thresh), save, imbalances)
