#!/usr/bin/env python

import datetime
import argparse
import multiprocessing
import itertools
import os
import random
import sys

from sshmm.sequence_container import readSequencesAndShapes
from sshmm.seqstructhmm import SeqStructHMM
from sshmm import seq_hmm
from sshmm.log import prepareLogger

def parseArguments(args):
    """Sets up the command-line parser and runs it on the given argument vector.

    arguments:
    args -- full argument vector in sys.argv layout, i.e. the first element is the
            program name (it is skipped) and the remaining elements are parsed.
            If None is given, argparse falls back to reading sys.argv itself.

    returns:
    the argparse namespace with the attributes data_directory, proteins,
    batch_directory and cores (cores is None if --cores was not given)"""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description='Trains a batch of Hidden Markov Model for the sequence-structure binding '
                                                 'preferences of a given set of RNA-binding protein. The models are trained on '
                                                 'sequences and structures in FASTA format located in a given data directory.\n'
                                                 'During the training process, statistics about the models are printed '
                                                 'on stdout. In every iteration, the current model and a visualization '
                                                 'of the model are stored in the batch directory.\n'
                                                 'The training process terminates when no significant progress has been '
                                                 'made for three iterations.')
    parser.add_argument('data_directory', type=str,
                        help='data directory (must have the following subdirectories: fasta/, shapes/, structures/')
    parser.add_argument('proteins', type=str,
                        help='list of RNA-binding proteins to analyze (surrounded by quotation marks)')
    parser.add_argument('batch_directory', type=str,
                        help='directory for batch output')
    # The default pool size is chosen in main() as cpu_count() - 1, so the help
    # text must not promise that every core is used.
    parser.add_argument('--cores', type=int, help='number of cores to use (if not given, all cores but one are used)')

    # Parse the arguments that were actually handed in instead of silently
    # re-reading the global sys.argv; element 0 is the program name.
    arguments_to_parse = None if args is None else list(args)[1:]
    return parser.parse_args(arguments_to_parse)


def do_training_for_configuration(data_directory, batch_directory, protein, structure_type, motif_length, flexibility, sample_size, output_interval, termination_criterion):
    """Trains one model for a single parameter combination and writes all output into its own job directory.

    arguments:
    data_directory -- directory containing fasta/<protein>/positive.fasta and <structure_type>/<protein>/positive.txt
    batch_directory -- batch output directory; the job directory is created below its jobs/ subdirectory
    protein -- name of the protein whose training data is read
    structure_type -- name of the data subdirectory holding the structure information
    motif_length -- motif length passed to the sequence reader and to the models
    flexibility -- flexibility parameter passed to SeqStructHMM
    sample_size -- sample size parameter passed to SeqStructHMM
    output_interval -- output interval passed to do_training
    termination_criterion -- termination criterion passed to do_training

    returns:
    a list consisting of the seven configuration values followed by the key
    statistics reported by the training run"""
    job_directory = '{0}/jobs/{1}_{2}_ml{3}_fl{4}_ss{5}_oi{6}_tc{7}'.format(batch_directory, protein, structure_type, motif_length, flexibility, sample_size, output_interval, termination_criterion)
    os.mkdir(job_directory)

    #see http://stackoverflow.com/questions/9209078/using-python-multiprocessing-with-different-random-seed-for-each-process
    random.seed() #reinitialize the random number generator because else random numbers would repeat

    main_logger = prepareLogger('main_logger', job_directory + '/verbose.log', verbose=True)
    numbers_logger = prepareLogger('numbers_logger', job_directory + '/numbers.log')

    sequence_path = '{0}/fasta/{1}/positive.fasta'.format(data_directory, protein)
    structure_path = '{0}/{1}/{2}/positive.txt'.format(data_directory, structure_type, protein)
    training_sequence_container = readSequencesAndShapes(sequence_path, structure_path, motif_length, main_logger)
    main_logger.info('Read %s training sequences', training_sequence_container.get_length())

    # Only the second element of the result (the Viterbi paths) is needed to initialize the model.
    best_baumwelch_sequence_models = seq_hmm.find_best_baumwelch_sequence_models(motif_length, training_sequence_container, main_logger)
    best_viterbi_paths = best_baumwelch_sequence_models[1]

    model0 = SeqStructHMM(0, job_directory, main_logger, numbers_logger, training_sequence_container, motif_length, flexibility, sample_size)
    model0.prepare_model_with_viterbi(best_viterbi_paths)

    # key_statistics was previously never assigned, which made the return statement
    # below fail with a NameError after the whole training had already run.
    # NOTE(review): assumes do_training returns the key statistics (iteration_number,
    # sequence_logll, sequence_structure_logll, matching the result header logged by
    # main) -- confirm against SeqStructHMM.do_training.
    key_statistics = model0.do_training(output_interval, termination_criterion, write_model_state=True)
    if key_statistics is None:
        key_statistics = []

    # Build the output paths with os.path.join: plain concatenation without a
    # separator placed these files next to the job directory instead of inside it.
    pwm_global = model0.get_pwm_global()
    pwm_global.write_to_file(os.path.join(job_directory, 'recovered_pwm_global.txt'))
    pwm_global.write_weblogo(os.path.join(job_directory, 'recovered_pwm_global.png'))

    pwm_best_sequences = model0.get_pwm_best_sequences()
    pwm_best_sequences.write_to_file(os.path.join(job_directory, 'recovered_pwm_best_sequences.txt'))
    pwm_best_sequences.write_weblogo(os.path.join(job_directory, 'recovered_pwm_best_sequences.png'))

    model0.model.printAsGraph(os.path.join(job_directory, 'model_at_end.png'), model0.sequence_container.get_length())
    model0.model.write(os.path.join(job_directory, 'model_at_end.xml'))
    return [protein, structure_type, motif_length, flexibility, sample_size, output_interval, termination_criterion] + list(key_statistics)


def do_training_for_configuration_star(argsAsList):
    """Adapter for pool.map-style calls, which hand over a single argument.

    arguments:
    argsAsList -- the positional arguments of do_training_for_configuration, packed into one iterable

    returns:
    whatever do_training_for_configuration returns for the unpacked arguments"""
    configuration = tuple(argsAsList)
    return do_training_for_configuration(*configuration)

def start_process(batchLogger):
    """Logs the name of the current process; used as the initializer of the worker pool.

    arguments:
    batchLogger -- logger whose info() method receives the start message"""
    worker_name = multiprocessing.current_process().name
    start_message = 'Starting ' + worker_name
    batchLogger.info(start_message)

def main(args):
    """Runs the training for every parameter combination in a worker pool and logs the collected results.

    A time-stamped batch directory (with a jobs/ subdirectory) is created below the
    batch directory given on the command line. One training job is started per
    combination of protein, structure type, motif length, flexibility, sample size,
    output interval and termination criterion.

    arguments:
    args -- command-line arguments to the program"""
    options = parseArguments(args)

    #Create output directory for job
    batch_directory = "{0}/{1}".format(options.batch_directory, datetime.datetime.now().strftime('%y%m%d_%H%M%S'))
    os.mkdir(batch_directory)
    os.mkdir(batch_directory + '/jobs')

    batch_logger = prepareLogger('batch_logger', batch_directory + '/batch.log', verbose=True, stdout=True)

    if options.cores:
        pool_size = options.cores
    else:
        # Leave one core free, but never ask for an empty pool: on a single-core
        # machine cpu_count() - 1 is 0 and Pool(processes=0) raises a ValueError.
        pool_size = max(1, multiprocessing.cpu_count() - 1)
    pool = multiprocessing.Pool(processes=pool_size,
                                initializer=start_process,
                                initargs=[batch_logger],
                                maxtasksperchild=2,
                                )

    # Parameter grid; every element of the cartesian product becomes one training job.
    proteins = options.proteins.split()
    motif_lengths = [4,5,6,7,8]
    structure_types = ['shapes']
    flexibilities = [0, 10]
    sample_sizes = [1]
    output_interval = [100]
    termination_criterion = [10]

    # The order of the factors must match the parameter order of do_training_for_configuration,
    # because each tuple is unpacked positionally by do_training_for_configuration_star.
    parameterCombinations = itertools.product([options.data_directory], [batch_directory], proteins, structure_types,
                                              motif_lengths, flexibilities, sample_sizes, output_interval, termination_criterion)
    # map_async(...).get(timeout) waits for all jobs, giving up only after the (very large) timeout.
    results = pool.map_async(do_training_for_configuration_star, parameterCombinations).get(9999999)
    pool.close() # no more tasks
    pool.join()  # wrap up current tasks

    batch_logger.info('All jobs done.')
    batch_logger.info('Results:')
    batch_logger.info('protein,structure_type,motif_length,flexibility,sample_size,output_interval,termination_criterion,iteration_number,sequence_logll,sequence_structure_logll')
    for configuration in results:
        batch_logger.info(','.join(map(str, configuration)))

# Script entry point: run main on the process arguments and use its return
# value as the exit status.
if __name__ == "__main__":
    exit_status = main(sys.argv)
    sys.exit(exit_status)