#!/usr/bin/env python

import datetime
import argparse
import sys
import os

from sshmm import seq_hmm
from sshmm.seqstructhmm import SeqStructHMM
from sshmm.sequence_container import readSequencesAndShapes
from sshmm.log import prepareLogger

def parseArguments(args):
    """Sets up the command-line parser and calls it on the given command-line arguments.

    arguments:
    args -- full argument vector of the program, laid out like sys.argv:
            args[0] is the program name and is skipped, args[1:] are parsed

    returns:
    the argparse namespace holding the parsed options"""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description='Trains a Hidden Markov Model for the sequence-structure binding '
                                                 'preferences of an RNA-binding protein. The model is trained on '
                                                 'sequences and structures from a CLIP-seq experiment given in '
                                                 'two FASTA-like files.\n'
                                                 'During the training process, statistics about the model are printed '
                                                 'on stdout. In every iteration, the current model and a visualization '
                                                 'of the model can be stored in the output directory.\n'
                                                 'The training process terminates when no significant progress has been '
                                                 'made for three iterations.')
    parser.add_argument('training_sequences', type=str,
                        help='FASTA file with sequences for training')
    parser.add_argument('training_structures', type=str,
                        help='FASTA file with RNA structures for training')
    parser.add_argument('--motif_length', '-n', type=int, default=6,
                        help='length of the motif that shall be found (default: 6)')
    parser.add_argument('--baum_welch', '-b', action='store_true', default=True,
                        help='should the model be initialized with a Baum-Welch optimized sequence motif (default: yes)')
    # --baum_welch is a store_true flag that already defaults to True, so on
    # its own it can never switch the initialization off. This companion flag
    # writes False into the same destination and makes random initialization
    # selectable.
    parser.add_argument('--no_baum_welch', dest='baum_welch', action='store_false',
                        help='initialize the model randomly instead of with a Baum-Welch optimized sequence motif')
    parser.add_argument('--flexibility', '-f', type=int, default=10,
                        help='greedyness of Gibbs sampler: model parameters are sampled from among the top f configurations (default: f=10), set f to 0 in order to include all possible configurations')
    parser.add_argument('--block_size', '-s', type=int, default=1,
                        help='number of sequences to be held-out in each iteration (default: 1)')
    parser.add_argument('--threshold', '-t', type=float, default=10.0,
                        help='the iterative algorithm is terminated if this reduction in sequence structure '
                        'loglikelihood is not reached for any of the 3 last measurements (default: 10)')
    parser.add_argument('--job_name', '-j', type=str, default="job",
                        help='name of the job (default: "job")')
    parser.add_argument('--output_directory', '-o', type=str, default=".",
                        help='directory to write output files to (default: current directory)')
    parser.add_argument('--termination_interval', '-i', type=int, default=100,
                        help='produce output every <i> iterations (default: i=100)')
    parser.add_argument('--no_model_state', '-w', action='store_true',
                        help='do not write model state every i iterations')
    parser.add_argument('--only_best_shape', action='store_true',
                        help='train only using best structure for each sequence (default: use all structures)')
    # Parse the arguments that were actually handed in rather than the global
    # sys.argv; the leading program name is not an option and is dropped.
    return parser.parse_args(args[1:])


def main(args):
    """Runs one complete training job and writes its results to a fresh job directory.

    The steps are: parse the options, create a time-stamped job directory
    below the output directory, set up the two loggers, read the training
    sequences and structures, initialize the model (from Baum-Welch derived
    Viterbi paths or randomly), train it, and finally write the sequence
    logos, the model graph and the model file into the job directory.

    arguments:
    args -- command-line arguments to the program (as in sys.argv)"""
    options = parseArguments(args)

    #Create output directory for job
    job_directory = "{0}/{1}_{2}/".format(options.output_directory, options.job_name, datetime.datetime.now().strftime('%y%m%d_%H%M%S'))
    # makedirs instead of mkdir: an output directory that does not exist yet
    # is created as well. An already existing job directory still raises.
    os.makedirs(job_directory)

    #Initialize logs
    main_logger = prepareLogger('main_logger', job_directory + options.job_name + '_verbose.log', verbose=True, stdout=True)
    numbers_logger = prepareLogger('numbers_logger', job_directory + options.job_name + '_numbers.log')
    main_logger.info("Call: %s", " ".join(args))
    main_logger.info("Chosen options:")
    main_logger.info("Motif Length: %s", options.motif_length)
    main_logger.info("Baum-Welch initialization: %s", "on" if options.baum_welch else "off")
    main_logger.info("Flexibility: top %s configurations", options.flexibility)
    main_logger.info("Block size: %s", options.block_size)
    main_logger.info("Termination threshold: %s", options.threshold)
    main_logger.info("Job name: %s", options.job_name)
    main_logger.info("Output directory: %s", options.output_directory)
    main_logger.info("Termination interval: %s iterations", options.termination_interval)
    main_logger.info("Write model state: %s", "off" if options.no_model_state else "on")
    main_logger.info("Structures: %s", "only the best of each sequence" if options.only_best_shape else "all")

    #Read in sequences and shapes
    training_sequence_container = readSequencesAndShapes(options.training_sequences, options.training_structures, options.motif_length, main_logger, options.only_best_shape)
    main_logger.info('Finished reading %s training sequences', training_sequence_container.get_length())

    #Initialize model
    model = SeqStructHMM(job_directory, main_logger, numbers_logger, training_sequence_container, options.motif_length,
                          options.flexibility, options.block_size)
    if options.baum_welch:
        best_baumwelch_sequence_model = seq_hmm.find_best_baumwelch_sequence_models(options.motif_length, training_sequence_container, main_logger)
        # Only the second element of the result (the Viterbi paths) is used here.
        best_viterbi_paths = best_baumwelch_sequence_model[1]
        model.prepare_model_with_viterbi(best_viterbi_paths)
    else:
        model.prepare_model_randomly()
    main_logger.info('Completed initialisation. Begin training..')

    #Train model
    # The last argument switches writing of the model state on, hence the negation.
    model.do_training(options.termination_interval, options.threshold, not options.no_model_state)
    main_logger.info('Completed training. Write sequence logos..')

    #Write results
    pwm_global = model.get_pwm_global()
    pwm_global.write_to_file(job_directory + 'logo_global.txt')
    pwm_global.write_weblogo(job_directory + 'logo_global.png')
    pwm_best_sequences = model.get_pwm_best_sequences()
    pwm_best_sequences.write_to_file(job_directory + 'logo_best_sequences.txt')
    pwm_best_sequences.write_weblogo(job_directory + 'logo_best_sequences.png')
    pwm_hairpin = model.get_pwm_hairpin()
    pwm_hairpin.write_to_file(job_directory + 'logo_hairpin.txt')
    pwm_hairpin.write_weblogo(job_directory + 'logo_hairpin.png')
    main_logger.info('Completed writing sequence logos. Print model graph..')

    graph_path = job_directory + 'final_graph.png'
    model.model.printAsGraph(graph_path, model.sequence_container.get_length())
    main_logger.info('Printed model graph: %s. Write model file..', graph_path)

    xml_path = job_directory + 'final_model.xml'
    model.model.write(xml_path)
    main_logger.info('Wrote model file: %s', xml_path)
    main_logger.info('Finished ssHMM successfully.')

# Script entry point: hand the full argument vector (including the program
# name) to main() and use its return value as the process exit status.
if __name__ == "__main__" :
    sys.exit(main(sys.argv))
