#!/usr/bin/env python

import sys
import argparse
import threading
import os
from subprocess import call, check_call, CalledProcessError
from itertools import izip
from sshmm.structure_prediction import calculate_rna_shapes_from_file, calculate_rna_structures_from_file

#Class defining a thread to run secondary structure prediction
class FoldThread(threading.Thread):
    """Daemon thread that runs a structure prediction script on a FASTA file.

    The script is started with the FASTA file connected to its standard input
    and the profile file connected to its standard output.

    arguments:
    script -- executable to run
    fasta -- path of the input file (opened for reading)
    profile -- path of the output file (opened for writing, which truncates it)
    """

    def __init__(self, script, fasta, profile):
        super(FoldThread, self).__init__()
        self.daemon = True
        self.script = script
        self.fasta = open(fasta, 'r')
        try:
            self.profile = open(profile, 'w')
        except (IOError, OSError):
            # Do not leak the input handle when the output cannot be created
            self.fasta.close()
            raise

    def run(self):
        """Run the script to completion and release both file handles afterwards."""
        try:
            call([self.script, '-W', '240', '-L', '160', '-u', '1'], stdin=self.fasta, stdout=self.profile)
        finally:
            # Closing here flushes the profile to disk and frees the descriptors
            # even if the script could not be started.
            self.profile.close()
            self.fasta.close()

def parseArguments(args):
    """Sets up the command-line parser and calls it on the command-line arguments to the program.

    arguments:
    args -- command-line arguments to the program, including the program name
            at index 0 (i.e. a list shaped like sys.argv); None falls back to
            sys.argv

    Returns the argparse namespace with the attributes directory, dataset_name,
    jump_to, min_score, genome, min_length, max_length and elongation.

    Note: jump_to and min_score are positional arguments without nargs and are
    therefore required; their default= values are never applied by argparse."""
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description="""Prepare a CLIP-Seq dataset in BED format for the training of an ssHMM. The following preprocessing steps are taken:
1 - Filter (positive) BED file, 2 - Shuffle (positive) BED file to generate negative dataset, 3 - Elongate positive and negative BED files for later structure prediction, 4 - Fetch genomic sequences for elongated BED files, 5 - Produce FASTA files with genomic sequences in viewpoint format, 6 - Calculate RNA shapes, 7 - Calculate RNA structures

This script requires awk, bedtools (shuffle, slop, getfasta), RNAshapes, and RNAstructures.

A root directory for the datasets and a dataset name (e.g., the protein name) has to be given. The following files will be created in the root directory and its subdirectories:
- /bed/<dataset_name>/positive_raw.bed - positive BED file from CLIP-Seq experiment
- /bed/<dataset_name>/positive.bed - filtered positive BED file
- /bed/<dataset_name>/negative.bed - filtered negative BED file
- /bed/<dataset_name>/positive_long.bed - elongated positive BED file
- /bed/<dataset_name>/negative_long.bed - elongated negative BED file
- /temp/<dataset_name>/positive_long.fasta - genomic sequences of elongated positive BED file
- /temp/<dataset_name>/negative_long.fasta - genomic sequences of elongated negative BED file
- /fasta/<dataset_name>/positive.fasta - positive genomic sequences in viewpoint format
- /fasta/<dataset_name>/negative.fasta - negative genomic sequences in viewpoint format
- /shapes/<dataset_name>/positive.txt - secondary structures of positive genomic sequence (predicted by RNAshapes)
- /shapes/<dataset_name>/negative.txt - secondary structures of negative genomic sequence (predicted by RNAshapes)
- /structures/<dataset_name>/positive.txt - secondary structures of positive genomic sequence (predicted by RNAstructures)
- /structures/<dataset_name>/negative.txt - secondary structures of negative genomic sequence (predicted by RNAstructures)

The preprocessing step to begin with can be chosen. For each step, the files generated by the previous step need to be present. To execute all steps, only the positive_raw.bed must be present.

For the filtering step, the minimum score and binding site lengths can be defined with parameters. To fetch genomic sequences in step 4, the following file must be present in the genomes/ subdirectory:

- /genomes/[version]/[version].genome -> BED file defining the size of the chromosomes
- /genomes/[version]/UCSCGenesTrack.bed -> BED file defining the gene intervals
- /genomes/[version]/[version].fa -> FASTA file containing the human genome

The version of the genome can be given as an optional parameter. It defaults to 'hg19'. The files for the genomes/ directory can be obtained from UCSC:

/genomes/[version]/[version].genome:
-> download from http://hgdownload.soe.ucsc.edu/downloads.html#human (Full data set), e.g. http://hgdownload.soe.ucsc.edu/goldenPath/hg19/bigZips/hg19.chrom.sizes

/genomes/[version]/UCSCGenesTrack.bed:
-> download in table browser (http://genome.ucsc.edu/cgi-bin/hgTables); choose most recent GENCODE track (currently GENCODE Gene V24lift37->Basic (for hg19) and All GENCODE V24->Basic (for hg38)) and 'BED' as output format

/genomes/[version]/[version].fa:
-> download chromosomes from http://hgdownload.soe.ucsc.edu/downloads.html; e.g. wget --timestamping 'ftp://hgdownload.cse.ucsc.edu/goldenPath/hg19/chromosomes/*'; concatenate chromosomes with cat and print into .fa file (e.g. with 'zcat chr* > hg19.fa')
""")
    parser.add_argument('directory', type=str, help='root directory for data')
    parser.add_argument('dataset_name', type=str, help='dataset name')
    parser.add_argument('jump_to', type=int, default=1, help='preprocessing step to jump to (as integer): '
                                                             '1 - filter bed, 2 - shuffle bed, '
                                                             '3 - enlongate bed, 4 - fetch sequences, 5 - format FASTA, '
                                                             '6 - calculate RNA shapes, 7 - calculate RNA structures')
    parser.add_argument('min_score', type=float, default=0.0, help='minimum score for binding site (default: 0.0)')
    parser.add_argument('--genome', '-g', type=str, default='hg19', help='genome version to use (default: hg19)')
    parser.add_argument('--min_length', type=int, default=8, help='minimum binding site length (default: 8)')
    parser.add_argument('--max_length', type=int, default=75, help='maximum binding site length (default: 75)')
    parser.add_argument('--elongation', '-e', type=int, default=20, help='span for up- and downstream elongation of binding sites (default: 20)')
    # Parse the list that was handed in instead of silently re-reading sys.argv;
    # argparse expects the arguments without the leading program name.
    return parser.parse_args(None if args is None else args[1:])

def _executablesFound(commands, sink):
    """Run each command with its output sent to sink; return False if an executable does not exist.

    Only a missing executable (OSError with ENOENT) is reported as False.
    Every other error, including a non-zero exit status (CalledProcessError),
    propagates to the caller."""
    # Import errno directly: os.errno is not a documented attribute of the os module.
    import errno

    for command in commands:
        try:
            check_call(command, stdout=sink, stderr=sink)
        except OSError as e:
            if e.errno != errno.ENOENT:
                # Something else went wrong while trying to run the program
                raise
            return False
    return True


def checkPrereqs():
    """Check that bedtools, awk, RNAshapes and RNAstructures (Fold, ct2dot) can be executed.

    The programs are probed in that order. For the first one that is missing,
    an error is printed to standard error and False is returned; True is
    returned when all probes succeeded."""
    # The context manager closes the devnull handle on every exit path
    with open(os.devnull, 'w') as devnull:
        if not _executablesFound([["bedtools", "--version"]], devnull):
            sys.stderr.write("ERROR: bedtools not found. Tried to execute 'bedtools --version'.\n")
            return False

        # awk is located via 'which', which signals a miss with a non-zero exit status
        try:
            check_call(["which", "awk"], stdout=devnull, stderr=devnull)
        except CalledProcessError:
            sys.stderr.write("ERROR: awk not found. Tried to execute 'which awk'.\n")
            return False

        if not _executablesFound([["RNAshapes", "-v"]], devnull):
            sys.stderr.write("ERROR: RNAshapes not found. Tried to execute 'RNAshapes -v'.\n")
            return False

        if not _executablesFound([["Fold", "-v"], ["ct2dot", "-v"]], devnull):
            sys.stderr.write("ERROR: RNAstructures not found. Tried to execute 'Fold -v' and 'ct2dot -v'.\n")
            return False

    return True

def prepareFolderStructure(directory, proteinName):
    """Create the per-dataset subdirectories below the data root.

    arguments:
    directory -- root directory for the datasets; must already exist
    proteinName -- dataset name, used as the name of each subdirectory

    Creates <directory>/<category>/<proteinName> for the categories bed, fasta,
    temp, shapes and structures unless they exist already.

    Returns False (after printing an error) if the root is not an existing
    directory, True otherwise."""
    # isdir also rejects a root path that exists but is a regular file
    if not os.path.isdir(directory):
        sys.stderr.write('ERROR: Directory \'{0}\' not found.\n'.format(directory))
        return False

    for category in ('bed', 'fasta', 'temp', 'shapes', 'structures'):
        datasetDirectory = directory + '/' + category + '/' + proteinName
        if os.path.exists(datasetDirectory):
            continue
        try:
            os.makedirs(datasetDirectory)
        except OSError:
            # Tolerate the directory having been created concurrently between
            # the check and makedirs; any other failure is a real error.
            if not os.path.isdir(datasetDirectory):
                raise
    return True

def filteringBed(filteredBedFileName, rawBedFileName, lowest_score, min_length, max_length):
    """Step 1: filter the raw positive BED file by interval length and score using awk.

    arguments:
    filteredBedFileName -- path of the filtered BED file to write
    rawBedFileName -- path of the raw BED file to read
    lowest_score -- minimum value of column 5 for a line to be kept
    min_length -- minimum interval length (column 3 minus column 2)
    max_length -- maximum interval length (column 3 minus column 2)

    Progress and the number of retained lines are reported on standard error."""
    sys.stderr.write('###STEP 1: Filtering positive raw bed-file\n')
    sys.stderr.write('INPUT: {0}\n'.format(rawBedFileName))
    sys.stderr.write('OUTPUT: {0}\n'.format(filteredBedFileName))

    #filter for max length, min length, and min score
    awkProgram = '{{if ($3-$2 <= {0} && $3-$2 >= {1} && $5 >= {2}) {{print $0}}}}'.format(max_length, min_length, lowest_score)
    # 'with' closes the output even if awk cannot be started
    with open(filteredBedFileName, 'w') as filteredBedFile:
        call(['awk', awkProgram, rawBedFileName], stdout=filteredBedFile)

    # Count the retained binding sites without leaving the read handle open
    with open(filteredBedFileName) as filteredBedFile:
        num_filtered = sum(1 for _ in filteredBedFile)

    sys.stderr.write('STEP 1 finished (filtering resulted in {0} good binding sites)\n'.format(num_filtered))

def shuffleBed(bedPositiveFileName, bedHg19FileName, genesFileName, bedNegativeFileName):
    """Step 2: derive the negative BED file by running 'bedtools shuffle' on the positive one.

    arguments:
    bedPositiveFileName -- path of the positive BED file (passed as -i)
    bedHg19FileName -- path of the genome file (passed as -g)
    genesFileName -- path of the BED file passed as -incl
    bedNegativeFileName -- path of the negative BED file to write"""
    sys.stderr.write('###STEP 2: Shuffling positive binding sites to obtain negative binding sites\n')
    sys.stderr.write('INPUT: {0}\n'.format(bedPositiveFileName))
    sys.stderr.write('OUTPUT: {0}\n'.format(bedNegativeFileName))

    command = ['bedtools', 'shuffle', '-i', bedPositiveFileName, '-g', bedHg19FileName, '-incl', genesFileName]
    # 'with' closes the output file even if bedtools cannot be started
    with open(bedNegativeFileName, 'w') as bedNegativeFile:
        call(command, stdout=bedNegativeFile)

    sys.stderr.write('STEP 2 finished\n')

def elongatingBed(bedIntervalLongFileName, genomeFileName, bedIntervalFileName, elongation):
    """Step 3: write an elongated copy of a BED file by running 'bedtools slop'.

    arguments:
    bedIntervalLongFileName -- path of the elongated BED file to write
    genomeFileName -- path of the genome file (passed as -g)
    bedIntervalFileName -- path of the BED file to elongate (passed as -i)
    elongation -- value passed to bedtools slop as -b"""
    sys.stderr.write('###STEP 3: Elongating binding sites by ' + str(elongation) + ' nt for structure prediction\n')
    sys.stderr.write('INPUT: {0}\n'.format(bedIntervalFileName))
    sys.stderr.write('OUTPUT: {0}\n'.format(bedIntervalLongFileName))

    command = ['bedtools', 'slop', '-i', bedIntervalFileName, '-g', genomeFileName, '-b', str(elongation)]
    # 'with' closes the output file even if bedtools cannot be started
    with open(bedIntervalLongFileName, 'w') as bedIntervalLongFile:
        call(command, stdout=bedIntervalLongFile)

    sys.stderr.write('STEP 3 finished\n')

def fetchingSequences(genomeFastaFileName, fastaTempFileName, bedIntervalLongFileName):
    """Step 4: run fastaFromBed to fetch the sequences of the elongated BED intervals.

    arguments:
    genomeFastaFileName -- path of the genome FASTA file (passed as -fi)
    fastaTempFileName -- path of the FASTA file to write (passed as -fo)
    bedIntervalLongFileName -- path of the elongated BED file (passed as -bed)"""
    report = sys.stderr.write
    report('###STEP 4: Fetching nucleotide sequence for elongated binding sites\n')
    report('INPUT: {0}\n'.format(bedIntervalLongFileName))
    report('OUTPUT: {0}\n'.format(fastaTempFileName))

    command = ['fastaFromBed', '-s']
    command += ['-fi', genomeFastaFileName]
    command += ['-bed', bedIntervalLongFileName]
    command += ['-fo', fastaTempFileName]
    call(command)

    report('STEP 4 finished\n')

def formatFasta(fastaFormattedFileName, bedIntervalFileName, bedIntervalLongFileName, fastaTempFileName, genome):
    """Step 5: rewrite the fetched sequences into viewpoint format.

    The original and the elongated BED file are read in lockstep together with
    one header line and one sequence line per record from the temporary FASTA
    file. In the output, the part of each sequence covered by the original
    binding site is uppercase and the elongation on both sides is lowercase.

    arguments:
    fastaFormattedFileName -- path of the viewpoint-format FASTA file to write
    bedIntervalFileName -- path of the BED file with the original binding sites
    bedIntervalLongFileName -- path of the BED file with the elongated binding sites
    fastaTempFileName -- path of the FASTA file with the elongated sequences
    genome -- genome version, only used in error messages

    A record that fails a consistency check is reported on standard error and
    the user is asked whether to stop. Answering 'y' makes the function return
    False; any other answer skips the record. Returns True when all records
    were processed."""
    report = sys.stderr.write
    report('###STEP 5: Reformatting nucleotide sequences into viewpoint format (binding sites in uppercase, elongation in lowercase)\n')
    report('INPUT: {0}\n'.format(fastaTempFileName))
    report('OUTPUT: {0}\n'.format(fastaFormattedFileName))

    genomeHint = 'A common cause for this error is a mismatch in genome version between the bed-files and the \'genome\' parameter to this script. Check whether your binding sites are on {0}. If not, change the \'genome\' parameter accordingly.\n'.format(genome)

    def stopRequested():
        # Ask the user whether to abort after a reported problem
        return raw_input("Stop script (y/n)?") == 'y'

    # All four files are managed by 'with' so they are closed on every exit
    # path, including the early 'return False' statements below.
    with open(fastaTempFileName, 'r') as fastaTempFile, \
            open(fastaFormattedFileName, 'w') as fastaFormattedFile, \
            open(bedIntervalFileName, 'r') as bedIntervalFile, \
            open(bedIntervalLongFileName, 'r') as bedIntervalLongFile:
        for line, (x, y) in enumerate(izip(bedIntervalFile, bedIntervalLongFile), 1):
            intervalFields = x.strip().split('\t')
            intervalLongFields = y.strip().split('\t')
            start, end = int(intervalFields[1]), int(intervalFields[2])
            longStart, longEnd = int(intervalLongFields[1]), int(intervalLongFields[2])
            elongationUpstream = start - longStart
            viewpointLength = end - start
            totalBindingSiteLength = longEnd - longStart

            header = fastaTempFile.readline().strip()
            sequence = fastaTempFile.readline().strip()

            if totalBindingSiteLength <= 0:
                report('ERROR: Malformed elongated binding site (line {0})\n'.format(line))
                report('The start of the elongated binding site ({0}: {1}-{2}) is greater than its end\n'.format(
                    intervalLongFields[0], longStart, longEnd))
                report(genomeHint)
                if stopRequested():
                    return False
            elif intervalFields[0] != intervalLongFields[0] or not (longStart <= start <= end <= longEnd):
                # A real containment test: comparing the total length with the
                # sum of its three parts can never fail, and without this check
                # negative offsets would silently produce wrong slices below.
                report('ERROR: Binding sites in bed-files do not match correctly (line {0})\n'.format(line))
                # The chromosome name is text, so it is formatted as-is (no int())
                report('The elongated binding site ({0}: {1}-{2}) does not contain the original binding site ({3}: {4}-{5})\n'.format(
                    intervalLongFields[0], longStart, longEnd, intervalFields[0], start, end))
                report(genomeHint)
                if stopRequested():
                    return False
            elif len(sequence) != totalBindingSiteLength:
                report('ERROR: Binding site in bed-files does not match with the corresponding nucleotide sequence (line {0} in bed-files, sequence \'{1}\' in nucleotide sequence file)\n'.format(line, sequence))
                report('The length of the elongated binding site ({0} nt) is unequal to the length of the nucleotide sequence ({1} nt)\n'.format(totalBindingSiteLength, len(sequence)))
                if stopRequested():
                    return False
            else:
                viewpointEnd = elongationUpstream + viewpointLength
                formattedSequence = (sequence[:elongationUpstream].lower()
                                     + sequence[elongationUpstream:viewpointEnd].upper()
                                     + sequence[viewpointEnd:].lower())
                fastaFormattedFile.write(header + '\n')
                fastaFormattedFile.write(formattedSequence + '\n')

    report('STEP 5 finished\n')

    return True


def main(args):
    """Run the preprocessing pipeline, starting at the step selected by jump_to.

    arguments:
    args -- command-line arguments to the program (sys.argv)

    Steps 1-7 are executed in order; every step whose number is at least
    jump_to is run, first for the positive and then for the negative dataset
    where both exist.

    Returns 1 if a prerequisite is missing, the data directory is not found or
    the user stops the script during FASTA formatting, so that the process
    exits with a non-zero status. Returns None on success."""
    options = parseArguments(args)

    #Check required programs bedtools, awk, RNAshapes, and RNAstructures
    if not checkPrereqs():
        return 1

    #Prepare folder structure
    if not prepareFolderStructure(options.directory, options.dataset_name):
        return 1

    def datasetFile(category, fileName):
        # <directory>/<category>/<dataset_name>/<fileName>
        return options.directory + '/' + category + '/' + options.dataset_name + '/' + fileName

    def announce(title, inputFileName, outputFileName):
        # Print the banner of a step together with the files it reads and writes
        sys.stderr.write(title + '\n')
        sys.stderr.write('INPUT: {0}\n'.format(inputFileName))
        sys.stderr.write('OUTPUT: {0}\n'.format(outputFileName))

    #Filtering bed
    bedFileName = datasetFile('bed', 'positive_raw.bed')
    bedPositiveFileName = datasetFile('bed', 'positive.bed')
    if options.jump_to <= 1:
        filteringBed(bedPositiveFileName, bedFileName, options.min_score, options.min_length, options.max_length)

    #Genome paths
    genomeFileName = '{0}/genomes/{1}/{1}.genome'.format(options.directory, options.genome)
    genesFileName = '{0}/genomes/{1}/UCSCGenesTrack.bed'.format(options.directory, options.genome)
    genomeFastaFileName = '{0}/genomes/{1}/{1}.fa'.format(options.directory, options.genome)

    #Shuffle positive bed to get negative bed
    bedNegativeFileName = datasetFile('bed', 'negative.bed')
    if options.jump_to <= 2:
        shuffleBed(bedPositiveFileName, genomeFileName, genesFileName, bedNegativeFileName)

    #Elongate positive and negative bed
    bedPositiveLongFileName = datasetFile('bed', 'positive_long.bed')
    bedNegativeLongFileName = datasetFile('bed', 'negative_long.bed')
    if options.jump_to <= 3:
        elongatingBed(bedPositiveLongFileName, genomeFileName, bedPositiveFileName, options.elongation)
        elongatingBed(bedNegativeLongFileName, genomeFileName, bedNegativeFileName, options.elongation)

    #Fetch positive and negative sequences
    fastaTempPositiveFileName = datasetFile('temp', 'positive_long.fasta')
    fastaTempNegativeFileName = datasetFile('temp', 'negative_long.fasta')
    if options.jump_to <= 4:
        fetchingSequences(genomeFastaFileName, fastaTempPositiveFileName, bedPositiveLongFileName)
        fetchingSequences(genomeFastaFileName, fastaTempNegativeFileName, bedNegativeLongFileName)

    #Format positive and negative FASTA
    fastaPositiveFileName = datasetFile('fasta', 'positive.fasta')
    fastaNegativeFileName = datasetFile('fasta', 'negative.fasta')
    if options.jump_to <= 5:
        if not formatFasta(fastaPositiveFileName, bedPositiveFileName, bedPositiveLongFileName, fastaTempPositiveFileName, options.genome):
            return 1
        if not formatFasta(fastaNegativeFileName, bedNegativeFileName, bedNegativeLongFileName, fastaTempNegativeFileName, options.genome):
            return 1

    #Calculate positive and negative RNA shapes
    shapePositiveFileName = datasetFile('shapes', 'positive.txt')
    shapeNegativeFileName = datasetFile('shapes', 'negative.txt')
    if options.jump_to <= 6:
        announce('###STEP 6: Calculating RNAshapes', fastaPositiveFileName, shapePositiveFileName)
        calculate_rna_shapes_from_file(shapePositiveFileName, fastaPositiveFileName, 10)
        sys.stderr.write('STEP 6 finished\n')

        # Log the negative file names here (the positive ones were printed before)
        announce('###STEP 6: Calculating RNAshapes', fastaNegativeFileName, shapeNegativeFileName)
        calculate_rna_shapes_from_file(shapeNegativeFileName, fastaNegativeFileName, 10)
        sys.stderr.write('STEP 6 finished\n')

    #Calculate positive and negative RNA structures
    structuresPositiveFileName = datasetFile('structures', 'positive.txt')
    structuresNegativeFileName = datasetFile('structures', 'negative.txt')
    if options.jump_to <= 7:
        announce('###STEP 7: Calculating RNAstructures', fastaPositiveFileName, structuresPositiveFileName)
        calculate_rna_structures_from_file(structuresPositiveFileName, fastaPositiveFileName)
        sys.stderr.write('STEP 7 finished\n')

        # Log the negative input file here (the positive one was printed before)
        announce('###STEP 7: Calculating RNAstructures', fastaNegativeFileName, structuresNegativeFileName)
        calculate_rna_structures_from_file(structuresNegativeFileName, fastaNegativeFileName)
        sys.stderr.write('STEP 7 finished\n')

# Script entry point: hand the full argument vector to main() and use its
# return value as the process exit status.
if __name__ == "__main__":
    exit_status = main(sys.argv)
    sys.exit(exit_status)
