#!/usr/bin/env python

import datetime
import argparse
import multiprocessing
import itertools
import os
import random
import sys

from sshmm.sequence_container import readSequencesAndShapes
from sshmm.seqstructhmm import SeqStructHMM
from sshmm import seq_hmm
from sshmm.log import prepareLogger

def parseArguments(args):
    """Sets up the command-line parser and calls it on the command-line arguments to the program.

    arguments:
    args -- full argument vector of the program, including the program name as
            first element (main passes sys.argv); if None, sys.argv is used

    returns the argparse namespace holding the parsed options"""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description='Trains multiple Hidden Markov Models for the sequence-structure binding '
                                                 'preferences of a given set of RNA-binding proteins. The models are trained on '
                                                 'sequences and structures in FASTA format located in a given data directory.\n'
                                                 'During the training process, statistics about the models are printed '
                                                 'on stdout. In every iteration, the current model and a visualization '
                                                 'of the model are stored in the batch directory.\n'
                                                 'The training processes terminate when no significant progress has been '
                                                 'made for three iterations.')

    # Positional arguments
    parser.add_argument('data_directory', type=str,
                        help='data directory; must contain the sequence files under fasta/<protein>/positive.fasta and structure files under <structure_type>/<protein>/positive.txt')
    parser.add_argument('proteins', type=str,
                        help='list of RNA-binding proteins to analyze (surrounded by quotation marks, separated by whitespace)')
    parser.add_argument('batch_directory', type=str,
                        help='directory for batch output')

    # Optional arguments
    parser.add_argument('--cores', '-c', type=int, help='number of cores to use (if not given, all cores are used)')
    parser.add_argument('--structure_type', '-s', type=str, default='shapes',
                        help='structure type to use; must match location of structure files (see data_directory argument above) (default: shapes)')
    parser.add_argument('--motif_length', '-n', type=int, default=6,
                        help='length of the motifs that shall be found (default: 6)')
    # NOTE(review): store_true combined with default=True means this option is
    # always True, whether or not the flag is given on the command line.
    parser.add_argument('--baum_welch', '-b', action='store_true', default=True,
                        help='should the models be initialized with a Baum-Welch optimized sequence motif (default: yes)')
    parser.add_argument('--flexibility', '-f', type=int, default=10,
                        help='greedyness of Gibbs sampler: model parameters are sampled from among the top f configurations (default: f=10), set f to 0 in order to include all possible configurations')
    parser.add_argument('--block_size', type=int, default=1,
                        help='number of sequences to be held-out in each iteration (default: 1)')
    parser.add_argument('--threshold', '-t', type=float, default=10.0,
                        help='the iterative algorithm is terminated if this reduction in sequence structure '
                             'loglikelihood is not reached for any of the 3 last measurements (default: 10)')
    parser.add_argument('--termination_interval', '-i', type=int, default=100,
                        help='produce output every <i> iterations (default: i=100)')

    # Parse the vector that was handed in instead of silently falling back to the
    # global sys.argv. The first element is the program name and must be dropped;
    # passing None keeps argparse's default of reading sys.argv[1:].
    argv = None if args is None else list(args)[1:]
    return parser.parse_args(argv)


def do_training_for_configuration(data_directory, batch_job_directory, protein, structure_type, motif_length, baum_welch, flexibility, block_size, termination_interval, threshold):
    """Trains one model for a single protein and writes all results into its own job directory.

    arguments:
    data_directory -- directory holding fasta/<protein>/positive.fasta and <structure_type>/<protein>/positive.txt
    batch_job_directory -- directory of the batch run; the job directory is created inside it
    protein -- name of the protein to train on
    structure_type -- name of the sub-directory with the structure files
    motif_length -- motif length handed to the reader, the model and the Baum-Welch search
    baum_welch -- if true, initialize the model from the best Baum-Welch sequence model, else randomly
    flexibility -- passed on to SeqStructHMM
    block_size -- passed on to SeqStructHMM
    termination_interval -- passed on to the training call
    threshold -- passed on to the training call

    returns a list consisting of the protein name followed by the first element
    of what the model's do_training call returns"""
    # The job directory name encodes the configuration of this job
    job_directory = '{0}/{1}_{2}_ml{3}_fl{4}_bs{5}_ti{6}_tt{7}/'.format(
        batch_job_directory, protein, structure_type, motif_length,
        flexibility, block_size, termination_interval, threshold)
    os.mkdir(job_directory)

    # Re-seed the random number generator, otherwise random numbers would repeat
    random.seed()

    # Set up the two log files of this job
    verbose_log_path = job_directory + '/verbose.log'
    main_logger = prepareLogger('main_logger', verbose_log_path, verbose=True)
    numbers_log_path = job_directory + '/numbers.log'
    numbers_logger = prepareLogger('numbers_logger', numbers_log_path)

    # Load the training data
    sequence_path = '{0}/fasta/{1}/positive.fasta'.format(data_directory, protein)
    structure_path = '{0}/{1}/{2}/positive.txt'.format(data_directory, structure_type, protein)
    training_sequence_container = readSequencesAndShapes(sequence_path, structure_path,
                                                         motif_length, main_logger)
    main_logger.info('Read %s training sequences', training_sequence_container.get_length())

    # Build the model and choose its starting point
    model = SeqStructHMM(job_directory, main_logger, numbers_logger,
                         training_sequence_container, motif_length,
                         flexibility, block_size)
    if not baum_welch:
        model.prepare_model_randomly()
    else:
        best_baumwelch_sequence_model = seq_hmm.find_best_baumwelch_sequence_models(
            motif_length, training_sequence_container, main_logger)
        # The second entry of the result is used as the Viterbi paths
        model.prepare_model_with_viterbi(best_baumwelch_sequence_model[1])

    # Run the training; only the first element of the result is reported
    training_result = model.do_training(termination_interval, threshold, write_model_state=True)
    key_statistics = training_result[0]

    # Write every PWM as a text file and as a logo image, one PWM after the other
    pwm_outputs = (
        (model.get_pwm_global, 'logo_global.txt', 'logo_global.png'),
        (model.get_pwm_best_sequences, 'logo_best_sequences.txt', 'logo_best_sequences.png'),
        (model.get_pwm_hairpin, 'logo_hairpin.txt', 'logo_hairpin.png'),
    )
    for get_pwm, text_file_name, image_file_name in pwm_outputs:
        pwm = get_pwm()
        pwm.write_to_file(job_directory + text_file_name)
        pwm.write_weblogo(job_directory + image_file_name)

    # Store the final model as graph image and as XML
    number_of_sequences = model.sequence_container.get_length()
    model.model.printAsGraph(job_directory + 'final_graph.png', number_of_sequences)
    model.model.write(job_directory + 'final_model.xml')

    return [protein] + key_statistics


def do_training_for_configuration_star(argsAsList):
    """Calls do_training_for_configuration with the elements of one job as positional arguments.

    arguments:
    argsAsList -- sequence with the arguments of do_training_for_configuration, in order

    returns whatever do_training_for_configuration returns"""
    # The pool's map hands over one object per job, so it is unpacked here
    job_result = do_training_for_configuration(*argsAsList)
    return job_result

def start_process(batchLogger):
    """Logs the name of the current process; used as initializer of the worker pool.

    arguments:
    batchLogger -- logger that receives the message"""
    process_name = multiprocessing.current_process().name
    message = 'Starting ' + process_name
    batchLogger.info(message)

def main(args):
    """Runs one training job per protein in a pool of worker processes and logs the result statistics.

    arguments:
    args -- command-line arguments to the program, handed on to parseArguments

    returns None"""
    options = parseArguments(args)

    #Create output directory for job
    batch_job_directory = "{0}/batch_{1}".format(options.batch_directory, datetime.datetime.now().strftime('%y%m%d_%H%M%S'))
    os.mkdir(batch_job_directory)

    #Initialize log
    batch_logger = prepareLogger('batch_logger', batch_job_directory + '/batch.log', verbose=True, stdout=True)

    #Generate jobs: one argument tuple per protein, all other settings are shared.
    #The Baum-Welch setting is taken from the parsed options instead of a
    #hard-coded True, so that the job honors whatever the parser delivers.
    proteins = options.proteins.split()
    jobs = itertools.product([options.data_directory], [batch_job_directory], proteins, [options.structure_type],
                             [options.motif_length], [options.baum_welch], [options.flexibility],
                             [options.block_size], [options.termination_interval], [options.threshold])
    batch_logger.info('Started batch run on {0} proteins.'.format(len(proteins)))

    #Determine number of cores to use
    if options.cores:
        pool_size = options.cores
    else:
        #Leave one core free and do not start more workers than there are proteins,
        #but never go below one worker: on a single-core machine (or with an empty
        #protein list) the value would be 0 and the pool would refuse to start.
        pool_size = max(1, min(multiprocessing.cpu_count() - 1, len(proteins)))

    batch_logger.info('Using {0} cores.'.format(pool_size))
    pool = multiprocessing.Pool(processes=pool_size,
                                initializer=start_process,
                                initargs=[batch_logger],
                                maxtasksperchild=2,
                                )

    #Execute jobs; the results are fetched with a (very long) timeout
    results = pool.map_async(do_training_for_configuration_star, jobs).get(9999999)
    pool.close() # no more tasks
    pool.join()  # wrap up current tasks

    #Report one line of statistics per protein
    batch_logger.info('Finished all proteins.')
    batch_logger.info('Result statistics:')
    batch_logger.info('\t'.join(['Protein', 'Number of iterations', 'Sequence loglikelihood', 'Sequence-structure loglikelihood']))
    for configuration in results:
        batch_logger.info('\t'.join(map(str, configuration)))

# Script entry point: run main on the full argument vector and use its
# return value as the exit status of the process.
if __name__ == "__main__":
    exit_status = main(sys.argv)
    sys.exit(exit_status)
