#! /usr/bin/env python

from __future__ import print_function
import os,sys
import argparse

import errno
from time import time

from modules import get_sorted_fastq_for_cluster
from modules import cluster
from modules import p_minimizers_shared

'''
    The code below is taken from https://github.com/lh3/readfq/blob/master/readfq.py
'''

def readfq(fp): # this is a generator function
    """Yield ``(name, (sequence, quality))`` for each FASTA/FASTQ record in *fp*.

    Parameters:
        fp: an iterator over text lines, e.g. an open file handle.  The
            parser breaks out of one ``for`` loop over *fp* and continues in
            the next, so *fp* must resume where it stopped (a plain list
            would restart from the beginning; wrap it in ``iter()``).

    Yields:
        ``(name, (seq, qual))`` where *name* is the header line without its
        leading ``>``/``@`` and with spaces replaced by underscores, *seq* is
        the concatenation of all sequence lines, and *qual* is the
        concatenated quality string, or ``None`` for FASTA records and for a
        FASTQ record whose quality is cut short by the end of the input.
    """
    def chomp(line):
        # Strip the line terminator only when there is one.  Blindly dropping
        # the last character would eat a real base/quality value when the
        # final line of the input is not newline-terminated.
        return line[:-1] if line.endswith('\n') else line

    last = None # this is a buffer keeping the last unprocessed line
    while True: # mimic closure; is it a bad idea?
        if not last: # the first record or a record following a fastq
            for l in fp: # search for the start of the next record
                if l[0] in '>@': # fasta/q header line
                    last = chomp(l) # save this line
                    break
        if not last: break
        name, seqs, last = last[1:].replace(" ", "_"), [], None
        for l in fp: # read the sequence
            if l[0] in '@+>':
                last = chomp(l)
                break
            seqs.append(chomp(l))
        if not last or last[0] != '+': # this is a fasta record
            yield name, (''.join(seqs), None) # yield a fasta record
            if not last: break
        else: # this is a fastq record
            seq, leng, seqs = ''.join(seqs), 0, []
            for l in fp: # read the quality
                qual_line = chomp(l)
                seqs.append(qual_line)
                leng += len(qual_line)
                if leng >= len(seq): # have read enough quality
                    last = None
                    yield name, (seq, ''.join(seqs)) # yield a fastq record
                    break
            if last: # reach EOF before reading enough quality
                yield name, (seq, None) # yield a fasta record instead
                break




def mkdir_p(path):
    """Create directory *path*, including missing parents, like ``mkdir -p``.

    Prints a short notice when the directory is actually created.  A path
    that already exists as a directory is accepted silently; every other
    ``OSError`` is propagated to the caller.
    """
    try:
        os.makedirs(path)
        print("creating", path)
    except OSError as err:
        # Only "already exists, and it is a directory" is tolerated.
        if err.errno != errno.EEXIST or not os.path.isdir(path):
            raise

def main(args):
    """Sort the reads, cluster them and write the cluster tables.

    Reads ``args.outfolder``, ``args.k`` and ``args.w`` directly and passes
    *args* on to ``get_sorted_fastq_for_cluster.main`` and
    ``cluster.cluster_seqs``.  Sets ``args.outfile`` to
    ``<outfolder>/sorted.fastq`` before the sorting step is invoked.

    Two tab-separated files are written to ``args.outfolder``:

    * ``final_clusters.csv``: one ``cluster_id<TAB>read_accession`` line per
      clustered read, largest clusters first.
    * ``final_cluster_origins.csv``: one line per cluster with the fields of
      its ``cluster_seq_origin`` entry (id, accession, sequence, quality,
      score, error rate).

    In both files the last ``_``-separated field of each accession is
    removed before writing.
    """
    args.outfile = os.path.join(args.outfolder, "sorted.fastq")

    print("started sorting seqs")
    start = time()
    sorted_reads_fastq_file = get_sorted_fastq_for_cluster.main(args)
    # The last "_"-separated field of every accession is parsed as a float;
    # presumably a score appended by the sorting step -- TODO confirm.
    # The handle is closed as soon as all reads are loaded.
    with open(sorted_reads_fastq_file, 'r') as sorted_reads:
        read_array = [(i, 0, acc, seq, qual, float(acc.split("_")[-1]))
                      for i, (acc, (seq, qual)) in enumerate(readfq(sorted_reads))]
    print("elapsed time sorting:", time() - start)

    print("started imported empirical error probabilities of minimizers shared:")
    start = time()
    p_min_shared = p_minimizers_shared.read_empirical_p()
    p_emp_probs = {}
    for k, w, p, e1, e2 in p_min_shared:
        # Keep entries for the requested k whose window size is within 2 of
        # the requested w; store them under both orderings of the error pair.
        if int(k) == args.k and abs(int(w) - args.w) <= 2:
            p_emp_probs[(float(e1),float(e2))] = float(p)
            p_emp_probs[(float(e2),float(e1))] = float(p)

    print(p_emp_probs)
    print(len(p_emp_probs))
    print("elapsed time imported empirical error probabilities of minimizers shared:", time() - start)
    # sys.exit()
    print("started clustring")
    start = time()
    clusters, cluster_seq_origin = cluster.cluster_seqs(read_array, p_emp_probs,  args)
    print("elapsed time clustering:", time() - start)

    # reads = {acc: (seq,qual) for (acc, (seq, qual)) in  readfq(open(sorted_reads_fastq_file, 'r'))}
    # nonclustered_outfile = open(os.path.join(reads_outfolder, "non_clustered.fa"), "w")

    clusters_path = os.path.join(args.outfolder,  "final_clusters.csv")
    origins_path = os.path.join(args.outfolder,  "final_cluster_origins.csv")
    nontrivial_cluster_index = 0
    # Context managers guarantee both files are flushed and closed even if
    # writing fails part-way through.
    with open(clusters_path, "w") as outfile, open(origins_path, "w") as origins_outfile:
        for c_id, all_read_acc in sorted(clusters.items(), key = lambda x: len(x[1]), reverse=True):
            read_cl_id, b_i, acc, c_seq, c_qual, score, error_rate = cluster_seq_origin[c_id]
            origins_outfile.write("{0}\t{1}\t{2}\t{3}\t{4}\t{5}\n".format(read_cl_id, "_".join(acc.split("_")[:-1]), c_seq, c_qual, score, error_rate))

            for r_acc in all_read_acc:
                outfile.write("{0}\t{1}\n".format(c_id, "_".join(r_acc.split("_")[:-1])))
            if len(all_read_acc) > 1:
                nontrivial_cluster_index += 1
    print("Nr clusters larger than 1:", nontrivial_cluster_index) #, "Non-clustered reads:", len(archived_reads))
    print("Nr clusters (all):", len(clusters)) #, "Non-clustered reads:", len(archived_reads))

    # for cl_id, all_read_acc in sorted(clusters.items(), key = lambda x: len(x[1]), reverse=True):


def write_fastq(args):
    """Write one FASTQ file per cluster into ``args.outfolder``.

    Reads ``args.clusters``, a whitespace-separated table whose first two
    columns are cluster id and read accession, and ``args.fastq``, the reads
    themselves.  For every cluster with at least ``args.N`` reads a file
    ``<cluster id>.fastq`` is created in ``args.outfolder`` (the folder is
    created if needed).

    Raises:
        KeyError: if a clustered accession is not present in ``args.fastq``.
        IndexError: if a non-blank line of the clusters file has fewer than
            two columns.
    """
    from collections import defaultdict
    clusters = defaultdict(list)

    with open(args.clusters) as f:
        for line in f:
            items = line.split()
            if not items:
                # Tolerate blank lines (e.g. a trailing empty line).
                continue
            cl_id, acc = items[0], items[1]
            clusters[cl_id].append(acc)

    mkdir_p(args.outfolder)
    # Load all reads up front; the handle is closed once the dict is built.
    with open(args.fastq, 'r') as fastq_file:
        reads = { acc : (seq, qual) for acc, (seq, qual) in readfq(fastq_file)}

    for cl_id, accessions in clusters.items():
        if len(accessions) < args.N:
            continue
        cluster_path = os.path.join(args.outfolder, str(cl_id) + ".fastq" )
        with open(cluster_path, "w") as curr_file:
            for acc in accessions:
                seq, qual = reads[acc]
                curr_file.write("@{0}\n{1}\n{2}\n{3}\n".format(acc, seq, "+", qual))



    # curr_id = -1
    # for i, cl_id, acc in enumerate(open(args.clusters, "r")): 
    #     if i == 0:
    #         curr_file = open(os.path.join(args.outfolder, str(cl_id) + ".fastq" ), "w")
    #         seq, qual = reads[acc]
    #         curr_file.write("{0}\n{1}\n{2}\n{3}\n".format(acc, seq, "+", qual))

    #     elif curr_id != cl_id: # new cluster
    #         curr_file.close()
    #         curr_file = open(os.path.join(args.outfolder, str(cl_id) + ".fastq" ), "w")
    #         seq, qual = reads[acc]
    #         curr_file.write("{0}\n{1}\n{2}\n{3}\n".format(acc, seq, "+", qual))

    #     else: # same as before
    #         seq, qual = reads[acc]
    #         curr_file.write("{0}\n{1}\n{2}\n{3}\n".format(acc, seq, "+", qual))

    #     curr_id = cl_id



if __name__ == '__main__':
    # Command-line entry point.  Without a sub-command the clustering
    # pipeline (main) runs; the "write_fastq" sub-command instead splits an
    # existing clustering into one fastq file per cluster.
    parser = argparse.ArgumentParser(description="De novo clustering of long-read transcriptome reads", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--version', action='version', version='%(prog)s 0.0.3')

    parser.add_argument('--fastq', type=str,  default=False, help='Path to consensus fastq file(s)')
    parser.add_argument('--flnc', type=str, default=False, help='The flnc reads generated by the isoseq3 algorithm (BAM file)')
    parser.add_argument('--ccs', type=str, default=False, help='Path to consensus BAM file(s)')
    # parser.add_argument('--mapping', action="store_true", help='Only infer clusters by mapping, no alignment is performed.')
    parser.add_argument('--t', dest="nr_cores", type=int, default=8, help='Number of cores allocated for clustering')

    parser.add_argument('--ont', action="store_true", help='Clustering of ONT transcript reads.')
    parser.add_argument('--isoseq', action="store_true", help='Clustering of PacBio Iso-Seq reads.')

    parser.add_argument('--k', type=int, default=15, help='Kmer size')
    parser.add_argument('--w', type=int, default=50, help='Window size')
    parser.add_argument('--min_shared', type=int, default=5, help='Minmum number of minimizers shared between read and cluster')
    parser.add_argument('--mapped_threshold', type=float, default=0.7, help='Minmum mapped fraction of read to be included in cluster. The density of minimizers to classify a region as mapped depends on quality of the read.')
    parser.add_argument('--aligned_threshold', type=float, default=0.4, help='Minmum aligned fraction of read to be included in cluster. Aligned identity depends on the quality of the read.')
    parser.add_argument('--min_fraction', type=float, default=0.8, help='Minmum fraction of minimizers shared compared to best hit, in order to continue mapping.')
    parser.add_argument('--min_prob_no_hits', type=float, default=0.1, help='Minimum probability for i consecutive minimizers to be different between read and representative and still considered as mapped region, under assumption that they come from the same transcript (depends on read quality).')
    # The help text used to describe a fasta file; the value is used as the
    # folder that receives sorted.fastq and the final_cluster*.csv tables.
    parser.add_argument('--outfolder', type=str,  default=None, help='Output folder for the sorted reads and the cluster tables.')
    # parser.add_argument('--pickled_subreads', type=str, help='Path to an already parsed subreads file in pickle format')
    parser.set_defaults(which='main')

    subparsers = parser.add_subparsers(help='sub-command help')
    write_fastq_parser = subparsers.add_parser('write_fastq', help='a help')
    write_fastq_parser.add_argument('--clusters', type=str, help='The file "final_clusters.csv" created by isONclust.')
    write_fastq_parser.add_argument('--fastq', type=str, help='Input fastq file')
    write_fastq_parser.add_argument('--outfolder', type=str, help='Output folder')
    write_fastq_parser.add_argument('--N', type=int, default = 0, help='Write out clusters with more or equal than N reads')
    # parser.add_argument('--write_fastq_clusters', default = None, help=' --write_fastq_clusters <N>. Write out clusters with more or equal than N >= 1.')
    write_fastq_parser.set_defaults(which='write_fastq')

    args = parser.parse_args()

    if args.which == 'write_fastq':
        write_fastq(args)
        print("Wrote clusters to separate fastq files.")
        sys.exit(0)

    if (args.fastq and (args.flnc or args.ccs)):
        print("Either (1) only a fastq file, or (2) a ccs and a flnc file should be specified. ")
        sys.exit()

    # Exactly one of --flnc / --ccs was given (both default to False).
    if (args.flnc is False) != (args.ccs is False):
        print("isONclust needs both the ccs.bam file produced by ccs and the flnc file produced by isoseq3 cluster. ")
        sys.exit()

    if args.ont and args.isoseq :
        print("Arguments mutually exclusive, specify either --isoseq or --ont. ")
        sys.exit()
    elif args.isoseq:
        # Technology presets replace whatever --k / --w were given.
        args.k = 15
        args.w = 50
    elif args.ont:
        args.k = 13
        args.w = 20


    # No arguments at all, or no input reads: show usage and stop.
    if len(sys.argv)==1 or not (args.fastq or args.flnc or args.ccs):
        parser.print_help()
        sys.exit()

    # main() joins paths onto args.outfolder, so a missing value would only
    # surface later as an obscure error inside os.path.join; fail early.
    if not args.outfolder:
        print("Please specify an output folder with --outfolder.")
        sys.exit(1)

    if not os.path.exists(args.outfolder):
        os.makedirs(args.outfolder)


    # edlib_module = 'edlib'
    parasail_module = 'parasail'
    # if edlib_module not in sys.modules:
    #     print('You have not imported the {0} module. Only performing clustering with mapping, i.e., no alignment.'.format(edlib_module))
    # This script never imports parasail itself, so the check can only pass
    # if something imported above already loaded it -- presumably one of the
    # "modules" imports; TODO confirm.
    if parasail_module not in sys.modules:
        print('You have not imported the {0} module. Only performing clustering with mapping, i.e., no alignment!'.format(parasail_module))
        sys.exit(1)
    if 100 < args.w or args.w < args.k:
        print('Please specify a window of size larger or equal to k, and smaller than 100.')
        sys.exit(1)

    main(args)

