#!/usr/bin/env python

import datetime
import argparse
import sys
import os
from collections import Counter

from sshmm import seq_hmm
from sshmm.seqstructhmm import SeqStructHMM
from sshmm.sequence_container import readSequencesAndShapes
from sshmm.log import prepareLogger

def parseArguments(args):
    """Sets up the command-line parser and calls it on the command-line arguments to the program.

    The given argument list is parsed (not the global sys.argv), so the function
    can also be called programmatically with an explicit command line.

    arguments:
    args -- command-line arguments to the program, laid out like sys.argv, i.e. the
            first element is the program name and is skipped; if None, the parser
            falls back to sys.argv

    returns:
    the argparse.Namespace holding the parsed options"""
    description = ('Trains a Hidden Markov Model for the sequence-structure binding '
                   'preferences of a RNA-binding protein. The model is trained on '
                   'sequences and structures from a CLIP-seq experiment given in '
                   'two FASTA-like files.\n'
                   'During the training process, statistics about the model are printed '
                   'on stdout. In every iteration, the current model and a visualization '
                   'of the model are stored in the output directory.\n'
                   'The training process terminates when no significant progress has been '
                   'made for three iterations.')
    # RawDescriptionHelpFormatter keeps the explicit line breaks of the description.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description=description)

    # Positional arguments: the two input files.
    parser.add_argument('training_sequences', type=str,
                        help='FASTA file storing the training sequences')
    parser.add_argument('training_shapes', type=str,
                        help='FASTA file storing the training RNA shapes')

    # Model and training parameters.
    parser.add_argument('--motif_length', '-n', type=int, default=8,
                        help='length of the motif that shall be found (default: 8)')
    parser.add_argument('--baum_welch', '-b', action='store_true',
                        help='should the model be initialized with a Baum-Welch optimized sequence motif (default: no)')
    parser.add_argument('--flexibility', '-f', type=int, default=0,
                        help='number of top configurations that configuration to set shall be randomly drawn from (default: 0), set to 0 in order to include all possible configurations')
    parser.add_argument('--sample_size', '-s', type=int, default=1,
                        help='number of sequences to be sampled each iteration (default: 1)')
    parser.add_argument('--termination', '-t', type=float, default=50.0,
                        help='the iterative algorithm is terminated if this reduction in sequence structure '
                        'loglikelihood is not reached for any of the 3 last measurements')

    # Job naming and output control.
    parser.add_argument('--job_name', '-j', type=str, default="job",
                        help='name of the job (default: "job")')
    parser.add_argument('--output_directory', '-o', type=str, default=".",
                        help='directory to write output files to (default: current directory)')
    parser.add_argument('--output_interval', '-i', type=int, default=1000,
                        help='produce output every <i> iterations (default: i=1000)')
    parser.add_argument('--write_model_state', '-w', action='store_true',
                        help='write model state every i iterations')
    parser.add_argument('--only_best_shape', action='store_true',
                        help='train only using best shape for each sequence')

    if args is None:
        # No explicit command line given: let argparse read sys.argv itself.
        return parser.parse_args()
    # args mirrors sys.argv, so drop the leading program name before parsing.
    return parser.parse_args(args[1:])


def main(args):
    """Runs one complete training job and writes its results into a new job directory.

    Steps: parse the options, create the job directory, set up the loggers, read the
    training sequences and shapes, initialise the model (randomly or from Baum-Welch
    Viterbi paths), train it, and finally write the recovered PWMs, the model graph
    and the model XML.

    arguments:
    args -- full command line of the program as passed by the script entry point
            (sys.argv); it is logged and handed to parseArguments"""
    options = parseArguments(args)

    # Create output directory for job; its name carries the job name and a time stamp.
    timestamp = datetime.datetime.now().strftime('%y%m%d_%H%M%S')
    job_directory = "{0}/{1}_{2}/".format(options.output_directory, options.job_name, timestamp)
    # makedirs instead of mkdir: also creates the output directory itself if it
    # does not exist yet, instead of failing on the missing parent.
    os.makedirs(job_directory)

    log_prefix = job_directory + options.job_name
    main_logger = prepareLogger('main_logger', log_prefix + '_verbose.log', verbose=True, stdout=True)
    numbers_logger = prepareLogger('numbers_logger', log_prefix + '_numbers.log')

    # Record the call and the chosen options in the verbose log.
    main_logger.info("Call: %s", " ".join(args))
    main_logger.info("Chosen options:")
    main_logger.info("Motif Length: %s", options.motif_length)
    main_logger.info("Baum-Welch initialization: %s", "on" if options.baum_welch else "off")
    main_logger.info("Flexibility: top %s configurations", options.flexibility)
    main_logger.info("Sample size: %s", options.sample_size)
    main_logger.info("Termination: %s", options.termination)
    main_logger.info("Job name: %s", options.job_name)
    main_logger.info("Output directory: %s", options.output_directory)
    main_logger.info("Output interval: %s iterations", options.output_interval)

    # Read in sequences and shapes
    training_sequence_container = readSequencesAndShapes(options.training_sequences,
                                                         options.training_shapes,
                                                         options.motif_length,
                                                         main_logger,
                                                         options.only_best_shape)
    main_logger.info('Read %s training sequences', training_sequence_container.get_length())

    model0 = SeqStructHMM(0, job_directory, main_logger, numbers_logger, training_sequence_container,
                          options.motif_length, options.flexibility, options.sample_size)

    # Initialise the model either from the Viterbi paths of the best Baum-Welch
    # sequence model or randomly.
    if options.baum_welch:
        best_baumwelch_sequence_model = seq_hmm.find_best_baumwelch_sequence_models(options.motif_length,
                                                                                    training_sequence_container,
                                                                                    main_logger)
        # Element 1 of the result is what prepare_model_with_viterbi consumes.
        best_viterbi_paths = best_baumwelch_sequence_model[1]
        model0.prepare_model_with_viterbi(best_viterbi_paths)
    else:
        model0.prepare_model_randomly()

    main_logger.info('Completed initialisation. Begin training.')
    model0.do_training(options.output_interval, options.termination, options.write_model_state)
    main_logger.info('Completed training. Write PWMs.')

    # Each recovered PWM is written as a text file and as a weblogo image.
    pwm_getters = (('global', model0.get_pwm_global),
                   ('best_sequences', model0.get_pwm_best_sequences),
                   ('hairpin', model0.get_pwm_hairpin))
    for pwm_name, get_pwm in pwm_getters:
        pwm = get_pwm()
        pwm_path_prefix = job_directory + 'recovered_pwm_' + pwm_name
        pwm.write_to_file(pwm_path_prefix + '.txt')
        pwm.write_weblogo(pwm_path_prefix + '.png')

    main_logger.info('Completed writing PWMs. Print model graph.')
    graph_path = job_directory + 'model_at_end.png'
    model0.model.printAsGraph(graph_path, model0.sequence_container.get_length())
    main_logger.info('Printed model graph: %s Write model XML.', graph_path)
    xml_path = job_directory + 'model_at_end.xml'
    model0.model.write(xml_path)
    main_logger.info('Wrote model XML:%s', xml_path)

# Script entry point: run the training with the full command line and use
# the return value of main as the process exit status.
if __name__ == "__main__":
    exit_status = main(sys.argv)
    sys.exit(exit_status)
