#!/usr/bin/env python3

# ############################### #
# trainWoolf
#
# Script to build a woolf model.
#
# Anna Farrell-Sherman 11/5/18
# ############################### #

from argparse import ArgumentParser, ArgumentTypeError #used to create command line inputs
import sys # for error handeling

#import Woolf module for classifier building
from woolf import woolfClassifier

#function used to define the input of a range of numbers to the script
def parseRange(rangeString):
	'''Parse a range given on the command line into a list of integers.

	Accepted formats:
		'N'           -> [N]
		'LO-HI'       -> every integer from LO to HI, inclusive
		'LO-HI,JUMP'  -> every JUMP-th integer from LO up to HI, inclusive

	Raises ArgumentTypeError (which argparse reports as a usage error) when
	the string matches none of these formats, contains non-integer parts,
	uses a jump of zero, or describes an empty range (for example '10-1').
	'''
	res = None
	try:
		if ',' in rangeString:
			lo, hi = rangeString.split('-')
			hi, jump = hi.split(',')
			res = list(range(int(lo), int(hi)+1, int(jump)))
		elif '-' in rangeString:
			lo, hi = rangeString.split('-')
			res = list(range(int(lo), int(hi)+1))
		elif rangeString.isdigit():
			res = [int(rangeString)]
	except (ValueError, TypeError):
		#wrong number of pieces, non-integer text, or a jump of zero
		res = None

	#res is None when no format matched and empty when the range holds no values;
	#neither is usable as a list of hyperparameter values
	if not res:
		raise ArgumentTypeError("'" + str(rangeString) + "' is not a valid range of numbers. Expected formats: '1' '1-10' or '1-10,2'.")
	return res

#Build the command line interface; the parsed options end up in `args`
parser = ArgumentParser(description="create a Woolf Model based a feature table")
parser.add_argument("featureTable", help="a CSV file containing feature data from a set of sequences")

#choose machine learning algorithm (at most one of the two flags may be given)
group = parser.add_mutually_exclusive_group()
group.add_argument("-k", "--kNN", action="store_true", help="build a kNN classifier")
group.add_argument("-f", "--randomForest", action="store_true", help="build a random forest classifier")

#kNN arguments
#range defaults are lists so they have the same type as values produced by parseRange
parser.add_argument("-n", "--nNeighbors", default=list(range(1,20)), type=parseRange, help="number of neighboors for kNN classifier, ranges are expresed as low-hi,jump (1-7,2 would test 1,3,5 and 7)")

#random Forest arguments
parser.add_argument("-t", "--nTrees", default=list(range(1,15,2)), type=parseRange, help="number of trees for random forest classifier, ranges are expresed as low-hi,jump (1-7,2 would test 1,3,5 and 7)")
parser.add_argument("-l", "--minLeafSize", default=list(range(10,30,3)), type=parseRange, help="minimum size of leaves in each tree of the random forest classifier, ranges are expresed as low-hi,jump (1-7,2 would test 1,3,5 and 7)")

#set up learning parameters
parser.add_argument("-s", "--featureScaler", default='MinMaxScaler', help="a scikit learn scaler object to scale in the input features")
#type=int makes argparse reject non-numeric fold counts up front and keeps the
#value an int whether it comes from the default or from the command line
parser.add_argument("-c", "--crossValidationFolds", default=5, type=int, help="the number of cross validation folds to execute")
parser.add_argument("-a", "--accuracyMetric", default='MCC', help="a scikit learn accuracy metric for training")

#add functionality
parser.add_argument("-p", "--predictFeatureTable", help="a unclassified feature table to be predicted by the model")
parser.add_argument("-e", "--listErrors", action="store_true", help="include to see a list of which sequences in the training dataset were missclassified")


#add verbose argument
parser.add_argument("-v", "--verbose", action="store_true", help="inlcude to get more detailed output")

args = parser.parse_args()

##################################################################################

#Collect the learning parameters from the parsed arguments.
#Every one of these options has an argparse default, so the values are used
#directly; there are no separate fallback variables to fall back on.
scalingType = args.featureScaler
scoringMetric = args.accuracyMetric
crossFolds = args.crossValidationFolds
nNeighbors = args.nNeighbors
nTrees = args.nTrees
minLeafSize = args.minLeafSize

#Determine which classifier type was selected
if args.kNN:
	classifierType = 'kNN'
	print('Building kNN Woolf Model...')
elif args.randomForest:
	classifierType = 'fOREST'
	print('Building Random Forest Woolf Model...')
else:
	print('ERROR: Please choose either kNN (-k) or Random Forest (-f)')
	#exit with a non-zero status so calling shells and pipelines see the failure
	sys.exit(1)

if args.verbose:
	print("Cross-validation Folds: " + str(crossFolds))
	print("Scoring Metric: " + str(scoringMetric))
	print("Scaler type: " + str(scalingType))

try:
	#Build the model
	woolf = woolfClassifier.buildWoolf(classifierType, scalingType, int(crossFolds), scoringMetric, nNeighbors, nTrees, minLeafSize)

	#Train the model
	print("Training Model...")
	results = woolfClassifier.trainModel(woolf, args.featureTable)
except (ValueError, TypeError) as e:
	print("Invalid Input: " + str(e))
	sys.exit(1)
except FileNotFoundError:
	#report a missing training table the same way as a missing prediction table
	print("ERROR: Feature Table not found: " + args.featureTable)
	sys.exit(1)


#OUTPUT
#`results` is whatever trainModel returned; the code below reads its
#best_score_, best_params_ and cv_results_ attributes
print("~~~~~~    RESULTS    ~~~~~~")
print("Score of best classifier: " + str(results.best_score_))
print("Best Params:" + str(results.best_params_))
if args.verbose:
	print("Range of classifier scores across hyperparameters:")
	print("\tMax: " + str(results.best_score_))
	print("\tMin: " + str(min(results.cv_results_['mean_test_score'])))
	#NOTE(review): 'mean_train_score' is presumably only present when the search
	#was configured to record training scores - confirm in woolfClassifier
	print("Range of training scores across hyperparameters:")
	print("\tMax: " + str(max(results.cv_results_['mean_train_score'])))
	print("\tMin: " + str(min(results.cv_results_['mean_train_score'])))
print("~~~~~~               ~~~~~~")

#Print misclassified sequences
if args.listErrors:
	print("Listing misclassified instances")
	posErrors, negErrors = woolfClassifier.findMisclassified(woolf, args.featureTable)
	print("misclassified as positive class:\n " + str(posErrors))
	print("misclassified as negative class:\n " + str(negErrors))

#Predict new instances
if args.predictFeatureTable:
	print("Predicting novel instances")
	try:
		prediction, score = woolfClassifier.predictWoolf(woolf, args.predictFeatureTable)
		print(prediction)
		#NOTE(review): a score of exactly 0 is falsy and would not be printed;
		#confirm what predictWoolf returns when no score is available
		if score:
			print("Score on test data: " + str(score))
	except FileNotFoundError:
		print("ERROR: Feature Table not found: " + args.predictFeatureTable)
		sys.exit(1)
