#! /usr/bin/env python
"""
    Copyright (c) <2015> <Kristoffer Sahlin>



    Permission is hereby granted, free of charge, to any person obtaining a copy
    of this software and associated documentation files (the "Software"), to deal
    in the Software without restriction, including without limitation the rights
    to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    copies of the Software, and to permit persons to whom the Software is
    furnished to do so, subject to the following conditions:



    The above copyright notice and this permission notice shall be included in
    all copies or substantial portions of the Software.



    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
    AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    THE SOFTWARE.


"""

from __future__ import print_function
import os
import shutil
import sys
import copy
import argparse
import signal
from time import time
import re

from multiprocessing import Pool
import multiprocessing as mp

from modules import isocon_get_candidates
from modules import isocon_statistical_test
from modules import isocon_parameters
from modules.input_output import write_output, fasta_parser, fastq_parser
from modules import partitions
# from modules import identify_barcodes

def clean_dir(params):
    """
        Empties params.outfolder of everything except the main result files.

        Entries named logfile.txt, candidates_converged.fa, final_candidates.fa,
        not_converged.fa and cluster_info.tsv are left untouched. Any other regular
        file is unlinked and any other directory is removed recursively. If removing
        an entry fails, the exception is printed and the remaining entries are still
        processed.
    """
    protected = {"logfile.txt", "candidates_converged.fa", "final_candidates.fa", "not_converged.fa", "cluster_info.tsv"}
    removable = [entry for entry in os.listdir(params.outfolder) if entry not in protected]

    for entry in removable:
        target = os.path.join(params.outfolder, entry)
        try:
            if os.path.isfile(target):
                os.unlink(target)
                continue
            if os.path.isdir(target):
                shutil.rmtree(target)
        except Exception as err:
            # Best effort: report the problem and carry on with the next entry.
            print(err)

def run_remove_barcodes(params=None):
    """
        Placeholder for a barcode removal step; it currently does nothing.

        params is accepted (and ignored) so that the function can be invoked in the
        same way as the other run_* entry points, i.e. run_remove_barcodes(params),
        which is how the command dispatcher at the bottom of this file calls it.
        Calling it without arguments keeps working.
    """
    pass

def run_get_candidates(params):
    """
        Runs only the candidate-finding step on params.fl_reads.

        Sets up the log, calls isocon_get_candidates.find_candidate_transcripts and
        writes the total elapsed wall-clock time to params.logfile. The output folder
        is cleaned afterwards when params.cleanup is set.
    """
    started = time()
    initialize_logger(params)

    # No preprocessing of the reads happens here: the input file is used as given.
    params.read_file = params.fl_reads

    # The three returned values are not needed by this sub-command.
    _candidates_path, _partition, _reads_to_realign = isocon_get_candidates.find_candidate_transcripts(params.read_file, params)

    elapsed = str(time() - started)
    write_output.logger('TOTAL TIME ELAPSED FOR nearest_neighbor APPROACH:{0}'.format(elapsed), params.logfile)

    if not params.cleanup:
        return
    clean_dir(params)

def run_stat_filter(params):
    """
        Runs only the statistical filtering step on an existing candidate file.

        Steps:
            1. Read params.candidates and drop every candidate whose accession carries
               a 'support_<n>' tag with n < params.min_candidate_support. Accessions
               without such a tag are kept and a notice is printed.
            2. Write the remaining candidates to <outfolder>/preprocessed_candidates.fa
               and point params.candidates to that file.
            3. If that file is empty, call write_output.print_candidates with empty
               inputs for final_candidates.fa and exit with status 0.
            4. Otherwise load the reads in params.fl_reads (parsed as fastq or fasta
               depending on params.is_fastq) and call
               isocon_statistical_test.stat_filter_candidates.

        The output folder is cleaned when params.cleanup is set, and the total elapsed
        time is written to params.logfile.
    """
    total_start = time()
    initialize_logger(params)

    # The comprehension consumes the parser completely before the handle is closed.
    with open(params.candidates, 'r') as candidates_in:
        all_candidates = {acc: seq for (acc, seq) in fasta_parser.read_fasta(candidates_in)}
    print("nr initial candidates:", len(all_candidates))

    # Raw string: '\d' inside a normal string literal is an invalid escape sequence.
    support_pattern = re.compile(r'support_(\d+)')
    preprocessed_path = os.path.join(params.outfolder, "preprocessed_candidates.fa")
    counter = 0
    with open(preprocessed_path, "w") as preprocessed_candidates_file:
        for acc, seq in all_candidates.items():
            match = support_pattern.search(acc)
            if match:
                support = int(match.group(1))
                if support < params.min_candidate_support:
                    print("filter candidate with support:", support)
                    continue
            else:
                print("no support in accession")

            preprocessed_candidates_file.write(">{0}\n{1}\n".format(acc, seq))
            counter += 1

    params.candidates = preprocessed_path
    print("After preprocessed:", counter)

    if os.stat(params.candidates).st_size == 0:
        out_file_name = os.path.join(params.outfolder, "final_candidates.fa")
        write_output.print_candidates(out_file_name, {}, {}, {}, {}, params, final = True)
        print("Candidate file is empty!")
        sys.exit(0)

    with open(params.fl_reads, 'r') as reads_in:
        if params.is_fastq:
            to_realign = {acc: seq for (acc, seq, qual) in fastq_parser.readfq(reads_in)}
        else:
            to_realign = {acc: seq for (acc, seq) in fasta_parser.read_fasta(reads_in)}

    # Re-read the filtered file so the accessions are exactly what the parser yields for it.
    with open(params.candidates, 'r') as filtered_in:
        read_partition = {acc: {} for (acc, seq) in fasta_parser.read_fasta(filtered_in)}

    isocon_statistical_test.stat_filter_candidates(params.fl_reads, params.candidates, read_partition, to_realign, params)
    total_elapsed = time() - total_start
    if params.cleanup:
        clean_dir(params)
    write_output.logger('TOTAL TIME ELAPSED FOR STAT_FILTER:{0}'.format(str(total_elapsed)), params.logfile)

def initialize_logger(params):
    """
        Opens the output handles of a run and logs the parameter settings.

        Opens logfile.txt and filtered_reads.fa in params.outfolder for writing
        (truncating existing files) and stores the open handles as params.logfile and
        params.filtered_reads. Then one line per attribute of params is passed to
        write_output.logger with timestamp=False, followed by a "Starting." entry.
        The handles are left open for the rest of the run.
    """
    out_dir = params.outfolder
    log_handle = open(os.path.join(out_dir, "logfile.txt"), 'w')
    # filtered_reads is attached before logfile so the attribute order in the dump below stays the same.
    params.filtered_reads = open(os.path.join(out_dir, "filtered_reads.fa"), 'w')
    params.logfile = log_handle

    for name, value in vars(params).items():
        entry = "{0}:\t\t\t {1}".format(name, value)
        write_output.logger(entry, params.logfile, timestamp=False)

    write_output.logger("Starting.", params.logfile)


def run_pipeline(params):
    """
        runs:
            1. get_candidates() to find candidate transcripts out of CCS full length reads
            2. stat_filter() to filter candidates and only output significant ones
    """
    # NOTE: the docstring above is embedded verbatim in the help text of the
    # 'pipeline' sub-command, so it is user-facing output.
    pipeline_t0 = time()
    initialize_logger(params)

    # The reads are used as given; nothing rewrites the input path before step 1.
    params.read_file = params.fl_reads

    # Step 1: find candidate transcripts and time the step.
    step_t0 = time()
    found = isocon_get_candidates.find_candidate_transcripts(params.read_file, params)
    candidates_path, partition, reads_to_realign = found
    candidate_secs = time() - step_t0
    write_output.logger('TIME ELAPSED FOR FINDING CANDIDATES WITH nearest_neighborS:{0}'.format(str(candidate_secs)), params.logfile)

    # Step 2: statistical test of the candidates from step 1, timed separately.
    step_t0 = time()
    isocon_statistical_test.stat_filter_candidates(params.read_file, candidates_path, partition, reads_to_realign, params)
    test_secs = time() - step_t0
    write_output.logger('TIME ELAPSED FOR STATISTICAL TEST OF CANDIDATES:{0}'.format(str(test_secs)), params.logfile)

    if params.cleanup:
        clean_dir(params)

    # The total is taken after the optional cleanup, so it includes the cleanup time.
    overall_secs = time() - pipeline_t0
    write_output.logger('TOTAL TIME ELAPSED FOR IsoCon:{0}'.format(str(overall_secs)), params.logfile)


if __name__ == '__main__':
    # Command line entry point: build the parser with its three sub-commands
    # (pipeline, get_candidates, stat_filter), validate the read file, prepare the
    # output folders and dispatch to the matching run_* function.

    # create the top-level parser
    # The text is passed as description=; argparse's first positional parameter is
    # prog, which would make the whole sentence the program name in usage/--version.
    parent_parser = argparse.ArgumentParser(description="Pipeline for obtaining non-redundant haplotype specific transcript isoforms using PacBio IsoSeq reads.")
    # parent_parser.add_argument('--cleanup', action='store_true', help='Remove everything except logfile.txt, candidates_converged.fa and final_candidates.fa in output folder.')
    parent_parser.add_argument('--version', action='version', version='%(prog)s 0.3.3')

    subparsers = parent_parser.add_subparsers(help='help for subcommand')

    # create the parser for the "pipeline" command
    pipeline = subparsers.add_parser('pipeline', help='Run the entire pipeline, i.e., \n\n {0}'.format(run_pipeline.__doc__))
    pipeline_required = pipeline.add_argument_group('required arguments')
    pipeline_required.add_argument('-fl_reads', required=True, type=str, help='Fast<a/q> file pacbio Reads of Insert. ')
    pipeline_required.add_argument('-outfolder', required=True, type=str, help='Outfolder. ')

    pipeline.add_argument('--ccs', type=str, help='BAM/SAM file with CCS sequence predictions. ')

    pipeline.add_argument('--nr_cores', dest='nr_cores', type=int, default = 16, help='Number of cores to use. [default = 16]')
    pipeline.add_argument('--verbose', dest='verbose', action='store_true', help='This will print more information abount workflow and provide plots of similarity network etc. ')
    pipeline.add_argument('--neighbor_search_depth', type=int, default=2**32, help='Maximum number of pairwise alignments in search matrix to find nearest_neighbor. [default =2**32]')
    pipeline.add_argument('--min_exon_diff', type=int, default=20, help='Minimum consequtive base pair difference between two neigborss in order to remove edge. If more than this nr of consequtive base pair difference, its likely an exon difference. [default =20]')
    pipeline.add_argument('--min_candidate_support', type=int, default=2, help='Required minimum number of reads converged to the same sequence to be included in statistical test. [default 2]')

    pipeline.add_argument('--p_value_threshold', type=float, default=0.01, help='Threshold for statistical test, filter everythin below this threshold . [default = 0.01]')
    pipeline.add_argument('--min_test_ratio', type=int, default=5, help="Don't do tests where candidate c has more than <min_test_ratio> reads assigned to itself compared to the reference t, calculated as test_ratio = c/t, because c will likely be highly significant [default = 5]")
    pipeline.add_argument('--max_phred_q_trusted', dest='max_phred_q_trusted', type=int, default = 43, help='Maximum PHRED quality score trusted (T), linerarly remaps quality score interval [0,93] --> [0, T]. Quality scores may have some uncertainty since T is estimated from a consensus caller algorithm. ')
    pipeline.add_argument('--ignore_ends_len', type=int, default=15, help='Number of bp to ignore in ends. If two candidates are identical except in ends of this size, they are collapsed and the longest common substing is chosen to represent them. In statistical test step, the nearest neighbors are found based on ignoring the ends of this size. Also indels "hanging off" ends of this size will not be tested. [default 15].')
    pipeline.add_argument('--cleanup', action='store_true', help='Remove everything except logfile.txt, candidates_converged.fa and final_candidates.fa in output folder. [default = False]')
    pipeline.add_argument('--prefilter_candidates', action='store_true', help='Filter candidates if they are not consensus over any base pair in the candidate transcript formed from them, this can reduce runtime without significant loss in true candidates. [default = False]')
    pipeline.set_defaults(which='pipeline')


    # create the parser for the "get_candidates" command
    get_candidates = subparsers.add_parser('get_candidates', help='Get candidate transcripts with nearest_neighbor approach.')
    get_candidates_required = get_candidates.add_argument_group('required arguments')
    get_candidates_required.add_argument('-fl_reads', type=str, required=True, help='Full length RoIs. ')
    get_candidates_required.add_argument('-outfolder', type=str, required=True, help='Outfolder. ')

    get_candidates.add_argument('--nr_cores', dest='nr_cores', type=int, default = 16, help='Number of cores to use. [default = 16]')
    get_candidates.add_argument('--verbose', dest='verbose', action='store_true', help='This will print more information abount workflow and provide plots of similarity network etc. ')
    get_candidates.add_argument('--neighbor_search_depth', type=int, default=2**32, help='Maximum number of pairwise alignments in search matrix to find nearest_neighbor. [default =2**32]')
    get_candidates.add_argument('--min_exon_diff', type=int, default=20, help='Minimum consequtive base pair difference between two neigborss in order to remove edge. If more than this nr of consequtive base pair difference, its likely an exon difference. [default =20]')
    get_candidates.add_argument('--min_candidate_support', type=int, default=2, help='Required minimum number of reads converged to the same sequence to be included in statistical test.')
    get_candidates.add_argument('--ignore_ends_len', type=int, default=15, help='Number of bp to ignore in ends. If two candidates are identical expept in ends of this size, they are collapses and the longest common substing is chosen to represent them. In statistical test step, nearest_neighbors are found based on ignoring the ends of this size. Also indels in ends will not be tested. [default ignore_ends_len=15].')
    get_candidates.add_argument('--cleanup', action='store_true', help='Remove everything except logfile.txt, candidates_converged.fa and final_candidates.fa in output folder.')
    get_candidates.add_argument('--prefilter_candidates', action='store_true', help='Filter candidates if they are not consensus over at least one base pair, this can reduce runtime without significant loss in true candidates.')
    get_candidates.set_defaults(which='get_candidates')


    # create the parser for the "stat_filter" command
    # (the help text used to be a copy of the get_candidates description)
    stat_filter = subparsers.add_parser('stat_filter', help='Run the statistical test on already generated candidates and only output significant ones.')
    stat_filter_required = stat_filter.add_argument_group('required arguments')
    stat_filter_required.add_argument('-fl_reads', type=str, required=True, help='Full length RoIs. ')
    stat_filter_required.add_argument('-candidates', type=str, required=True, help='Fasta file of already generated candidates. ')
    stat_filter_required.add_argument('-outfolder', type=str, required=True, help='Outfolder. ')

    stat_filter.add_argument('--ccs', type=str, help='BAM/SAM file with CCS sequence predictions. ')

    stat_filter.add_argument('--nr_cores', dest='nr_cores', type=int, default = 16, help='Number of cores to use. [default = 16]')
    stat_filter.add_argument('--verbose', dest='verbose', action='store_true', help='This will print more information abount workflow and provide plots of similarity network etc. ')
    stat_filter.add_argument('--neighbor_search_depth', type=int, default=2**32, help='Maximum number of pairwise alignments in search matrix to find nearest_neighbor. [default =2**32]')
    stat_filter.add_argument('--min_exon_diff', type=int, default=20, help='Minimum consequtive base pair difference between two neigborss in order to remove edge. If more than this nr of consequtive base pair difference, its likely an exon difference. [default =20]')
    stat_filter.add_argument('--min_candidate_support', type=int, default=2, help='Required minimum number of reads converged to the same sequence to be included in statistical test. Can only be applies if fasta accessions comes from previous step in pipline')
    stat_filter.add_argument('--ignore_ends_len', type=int, default=15, help='Number of bp to ignore in ends. If two candidates are identical expept in ends of this size, they are collapses and the longest common substing is chosen to represent them. In statistical test step, nearest_neighbors are found based on ignoring the ends of this size. Also indels in ends will not be tested. [default ignore_ends_len=15].')
    stat_filter.add_argument('--p_value_threshold', type=float, default=0.01, help='Threshold for statistical test, filter everythin below this threshold . [default = 0.01]')
    stat_filter.add_argument('--min_test_ratio', type=int, default=5, help="Don't do tests where candidate c has more than <min_test_ratio> reads assigned to itself compared to the reference t, calculated as test_ratio = c/t, because c will likely be highly significant [default = 5]")
    stat_filter.add_argument('--max_phred_q_trusted', dest='max_phred_q_trusted', type=int, default = 43, help='Maximum PHRED quality score trusted (T), linerarly remaps quality score interval [0,93] --> [0, T]. Quality scores may have some uncertainty since T is estimated from a consensus caller algorithm. ')
    stat_filter.add_argument('--cleanup', action='store_true', help='Remove everything except logfile.txt, candidates_converged.fa and final_candidates.fa in output folder.')
    stat_filter.set_defaults(which='stat_filter')

    # stat_filter.add_argument('--consensus', type=str, help='Additional canditates supplied by user (This still runs the nearest_neighbor step, but additional candidates can help with guidance). ')

    # Invoked without any argument: show the help text instead of an argparse error.
    if len(sys.argv) == 1:
        sys.argv.append('-h')

    args = parent_parser.parse_args()

    # Check that the read file can be opened; the handle is closed again immediately.
    try:
        with open(args.fl_reads, "r"):
            pass
    except IOError as e:
        sys.exit("couldn't find read file: " + args.fl_reads + " check that the path is correct and that the file exists")

    # The input format is decided from the last character of the file name only.
    if args.fl_reads.endswith("q"):
        args.is_fastq = True
    elif args.fl_reads.endswith("a"):
        args.is_fastq = False
    else:
        # sys.stdout.write() accepts a single string; the previous two-argument call
        # raised a TypeError instead of showing this message.
        sys.exit("{0} has to be a valid fasta or fastq file and have a file extension that ends with 'a' (fasta, fa) or 'q' (fastq, fq).".format(args.fl_reads))

    # --ccs is not defined for every sub-command (get_candidates has none), so the
    # attribute may be missing from the namespace.
    if getattr(args, 'ccs', None) and args.is_fastq:
        sys.stdout.write('Warning: both --ccs and fastq file specified. IsoCon will use quality values from fastq file. If this is not desired, convert the fastq to a fasta file.\n')


    params = isocon_parameters.Parameters(**args.__dict__)
    params.minimap_alignments = os.path.join(params.outfolder, "minimapped")

    if not os.path.exists(params.outfolder):
        os.makedirs(params.outfolder)
    if params.verbose:
        params.plotfolder = os.path.join(params.outfolder, 'plots')
        if not os.path.exists(params.plotfolder):
            os.makedirs(params.plotfolder)
        params.develop_logfile = open(os.path.join(params.outfolder, 'develop_logfile.txt'), "w")

    params.tempfolder = os.path.join(params.outfolder, 'alignments')
    if not os.path.exists(params.tempfolder):
        os.makedirs(params.tempfolder)


    # 'which' is set by the set_defaults() call of the chosen sub-parser.
    if args.which == 'pipeline':
        run_pipeline(params)
    elif args.which == 'remove_barcodes':
        # No sub-parser above sets which='remove_barcodes', so this branch is currently unreachable.
        run_remove_barcodes(params)
    elif args.which == 'get_candidates':
        run_get_candidates(params)
    elif args.which == 'stat_filter':
        run_stat_filter(params)
    else:
        print('invalid call')
