#!/usr/bin/env python
"""
Read in one or more fastq files. For each read, do a 6-frame translation and add all
metapeptides that pass the specified filtering criteria. If --metagenefile is specified,
start with the output of MetaGene Annotator instead of raw reads.
"""

import argparse
import logging
from datetime import datetime
import pysam
from sixgill import metapeptides
import sys
from collections import OrderedDict
from Bio import bgzf
import csv
import os

# Module metadata
__author__ = "Damon May"
__copyright__ = "Copyright (c) 2016 Damon May"
__license__ = "Apache 2.0"
__version__ = "0.1"

# logging: module-level logger for this script. main() lowers its level to DEBUG when
# --debug is passed (and does the same for metapeptides.logger).
logger = logging.getLogger(__name__)


def declare_gather_args():
    """
    Declare all command-line arguments and parse them from sys.argv.

    Does no validation beyond the implicit validation done by argparse (argument types,
    the required --out, at least one fastq file). File-typed arguments are opened by
    argparse at parse time; note that --out and --outfasta are therefore created/truncated
    immediately.
    return: an argparse.Namespace with one attribute per declared argument
    """

    # declare args
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('fastqfiles', type=argparse.FileType('r'), nargs='+',
                        help='input fastq file(s), bgzipped')
    parser.add_argument('--minlength', type=int, help='min AA length of a metapeptide',
                        default=metapeptides.DEFAULT_MIN_METAPEPTIDE_AALENGTH)
    parser.add_argument('--minqualscore', type=int, help='min base-call phred score across any NT in a metapeptide',
                        default=metapeptides.DEFAULT_MIN_AA_QUALSCORE)
    parser.add_argument('--metagenefile', type=argparse.FileType('r'),
                        help='input MetaGene Annotator output file. Records must be in same linear order as reads in fastqfiles')
    # MetaGene scores are parsed as floats (see read_metagene_genes), so accept a float
    # threshold too; integer values are still accepted exactly as before.
    parser.add_argument('--minmetagenescore', type=float, help='minimum MetaGene score',
                        default=metapeptides.METAGENE_SCORE_MISSING)
    parser.add_argument('--minorflength', type=int, help='min length of ORF-portion',
                        default=metapeptides.DEFAULT_MIN_ORF_LENGTH)
    parser.add_argument('--minlongesttryppeplen', type=int,
                        default=metapeptides.DEFAULT_MIN_LONGEST_PEPTIDE_LENGTH,
                        help='minimum length of the longest tryptic peptide')
    parser.add_argument('--maxreads', type=int, help='stop early if we hit this many reads')
    parser.add_argument('--minreadcount', type=int, default=metapeptides.DEFAULT_MIN_READ_COUNT,
                        help='minimum read count')
    parser.add_argument('--out', required=True, type=argparse.FileType('w'),
                        help='Output metapeptide database file')
    parser.add_argument('--outfasta', type=argparse.FileType('w'),
                        help='Output metapeptide fasta database file')
    parser.add_argument('--debug', action="store_true", help='Enable debug logging')
    return parser.parse_args()


def _record_metapeptide(metapeptide_data_map, metapeptide, out_tempfile):
    """
    Add one passing metapeptide occurrence to the running accounting.

    The first time a metapeptide sequence is seen, its full output line is written to
    out_tempfile and a new entry [read count, quality score, alternative nt seqs or None]
    is added to metapeptide_data_map. Later occurrences only update that entry:
      - the read count is incremented;
      - the quality score becomes the larger of the stored and new values;
      - the occurrence's nt sequences are appended unless its first nt sequence is already
        listed.
    return: True if this was a new sequence (a row was written to out_tempfile), else False
    """
    data = metapeptide_data_map.get(metapeptide.sequence)
    if data is None:
        metapeptide_data_map[metapeptide.sequence] = [1, metapeptide.min_qualscore, None]
        out_tempfile.write(metapeptide.make_output_line() + "\n")
        return True
    data[0] += 1
    data[1] = max(data[1], metapeptide.min_qualscore)
    if data[2] is None:
        # copy rather than alias the metapeptide's own list, since we may extend it later
        data[2] = list(metapeptide.nt_seqs)
    elif metapeptide.nt_seqs[0] not in data[2]:
        data[2].extend(metapeptide.nt_seqs)
    return False


def _report_progress(n_processed, metapeptide_data_map, xvals, yvals, start_time, out_tempfile):
    """
    Periodic progress bookkeeping shared by both metapeptide-building modes.

    Every 50000 records:
      - append a (records processed, distinct metapeptides so far) point to xvals/yvals;
      - flush stdout and the temp file.
    Every 100000 records, also print a progress line with the elapsed time.
    """
    if n_processed % 50000 != 0:
        return
    xvals.append(n_processed)
    yvals.append(len(metapeptide_data_map))
    if n_processed % 100000 == 0:
        print("    Processed %d records. %d metapeptides so far. Time=%s" %
              (n_processed, len(metapeptide_data_map), datetime.now() - start_time))
    sys.stdout.flush()
    if out_tempfile:
        out_tempfile.flush()


def main():
    """
    Build a metapeptide database from one or more fastq files.

    Candidate metapeptides come from one of two sources:
      - every reading frame of every read (extract_read_metapeptides);
      - with --metagenefile, only the gene calls in MetaGene Annotator output
        (extract_frame_metapeptide).
    Each distinct metapeptide is written once to a bgzipped temp file. Its read count,
    quality score and alternative nt sequences are tallied in memory.

    A second pass merges those tallies into the final bgzipped database (and the optional
    fasta), dropping metapeptides seen in fewer than --minreadcount reads. It then deletes
    the temp file.
    """
    args = declare_gather_args()
    # logging
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s: %(message)s")
    if args.debug:
        logger.setLevel(logging.DEBUG)
        metapeptides.logger.setLevel(logging.DEBUG)

    script_start_time = datetime.now()
    logger.debug("Start time: %s", script_start_time)

    # Map from metapeptide sequence to [count of reads it occurs in, quality score,
    # alternative nt seq list or None]. Insertion order matches the row order of the temp
    # file, which the merge pass below relies on.
    metapeptide_data_map = OrderedDict()

    # for plotting read count vs. # metapeptides
    readsprocessed_nmetapeptides_xvals = []
    readsprocessed_nmetapeptides_yvals = []

    # read back some of the filtering parameters
    print("Min ORF length: %d" % args.minorflength)
    print("Minimum quality score: %d" % args.minqualscore)
    print("Minimum metapeptide AA sequence length: %d" % args.minlength)

    out_filename = args.out.name
    # Write a temporary file for storing full metapeptide information. This
    # can easily become too big to fit in memory.
    out_metapeptide_tempfilename = out_filename + '.temp'
    out_metapeptide_tempfile = bgzf.BgzfWriter(out_metapeptide_tempfilename, "w")
    out_metapeptide_tempfile.write('\t'.join(metapeptides.METAPEPTIDEDB_COLUMNS) + "\n")

    # accounting
    n_reads_processed = 0
    n_frames_discarded_tooshort = 0
    n_frames_discarded_minqualscore = 0
    n_frames_discarded_stopcodon = 0
    n_frames_discarded_longestpep_tooshort = 0
    n_frames_discarded_toofew_trypticsites = 0
    n_candidates_discarded_minmetagenescore = 0

    n_frames_used_in_metapeptides = 0
    n_metapeptides_written_temp = 0

    if args.metagenefile:
        print("Building metapeptides from MetaGene output...")
        sys.stdout.flush()
        # only keep metapeptides that have a MetaGene score > a threshold.
        # So loop on MetaGene results file.
        n_metagenes_processed = 0
        cur_fastq_fileidx = 0
        cur_fastq = pysam.FastqFile(args.fastqfiles[cur_fastq_fileidx].name)
        print("Processing fastq file %s..." % args.fastqfiles[cur_fastq_fileidx].name)
        # current read; None until the first read is pulled (tolerates an empty first file)
        aread = None
        for gene_readname, startpos, endpos, strand, frame, score in read_metagene_genes(args.metagenefile):
            if args.maxreads and n_metagenes_processed >= args.maxreads:
                print("STOPPING EARLY, processed %d metagenes" % n_metagenes_processed)
                break
            n_metagenes_processed += 1
            _report_progress(n_metagenes_processed, metapeptide_data_map,
                             readsprocessed_nmetapeptides_xvals, readsprocessed_nmetapeptides_yvals,
                             script_start_time, out_metapeptide_tempfile)
            # check MetaGene score
            if score < args.minmetagenescore:
                n_candidates_discarded_minmetagenescore += 1
                continue

            # Passes. Find the read. MetaGene records are required to be in the same linear
            # order as the reads, so advance forward through the fastq file(s) until we reach
            # it. Several genes may be called on one read, in which case we don't advance.
            while aread is None or aread.name != gene_readname:
                try:
                    aread = next(cur_fastq)
                except StopIteration:
                    # no more reads this file. Open the next file; the loop then pulls its
                    # first read (and moves on again if that file is empty).
                    cur_fastq_fileidx += 1
                    if cur_fastq_fileidx >= len(args.fastqfiles):
                        raise ValueError("Read %s from MetaGene output was not found in the fastq files. "
                                         "Are MetaGene records in the same order as the reads?"
                                         % gene_readname)
                    cur_fastq = pysam.FastqFile(args.fastqfiles[cur_fastq_fileidx].name)
                    print("Processed %d records. %d metapeptides so far" %
                          (n_metagenes_processed, len(metapeptide_data_map)))
                    print("Processing fastq file %s..." % args.fastqfiles[cur_fastq_fileidx].name)

            ntseq = str(aread.sequence)
            # have to subtract offset of 33 from pysam basecall quality
            phred_qualscores = [ord(x) - 33 for x in aread.quality]

            metapeptide, status = metapeptides.extract_frame_metapeptide(ntseq, phred_qualscores,
                                                                         args.minlength,
                                                                         args.minqualscore, args.minorflength,
                                                                         args.minlongesttryppeplen,
                                                                         strand == '-', frame,
                                                                         startpos, endpos,
                                                                         should_keep_with_cterm_stop=True,
                                                                         metagene_score=score)
            # accounting
            if status == metapeptides.METAPEPTIDE_STATUS_BAD_TOOSHORT:
                n_frames_discarded_tooshort += 1
            elif status == metapeptides.METAPEPTIDE_STATUS_BAD_MINQUALSCORE:
                n_frames_discarded_minqualscore += 1
            elif status == metapeptides.METAPEPTIDE_STATUS_BAD_STOPCODON:
                n_frames_discarded_stopcodon += 1
            elif status == metapeptides.METAPEPTIDE_STATUS_BAD_LONGESTPEP_TOOSHORT:
                n_frames_discarded_longestpep_tooshort += 1
            elif status == metapeptides.METAPEPTIDE_STATUS_BAD_TOOFEW_TRYPTICSITES:
                n_frames_discarded_toofew_trypticsites += 1
            if status != metapeptides.METAPEPTIDE_STATUS_OK:
                continue
            n_frames_used_in_metapeptides += 1
            if _record_metapeptide(metapeptide_data_map, metapeptide, out_metapeptide_tempfile):
                n_metapeptides_written_temp += 1
    else:
        print("Building metapeptides from all reads...")
        sys.stdout.flush()
        # no metagene filtering; try all 6 reading frames of every read
        for fastq_file in args.fastqfiles:
            if args.maxreads and n_reads_processed >= args.maxreads:
                break
            print("Processing fastq file %s..." % fastq_file.name)

            fastq = pysam.FastqFile(fastq_file.name)
            for aread in fastq:
                if args.maxreads and n_reads_processed >= args.maxreads:
                    print("STOPPING EARLY, processed %d reads" % n_reads_processed)
                    break
                n_reads_processed += 1
                _report_progress(n_reads_processed, metapeptide_data_map,
                                 readsprocessed_nmetapeptides_xvals, readsprocessed_nmetapeptides_yvals,
                                 script_start_time, out_metapeptide_tempfile)

                ntseq = str(aread.sequence)
                # have to subtract offset of 33 from pysam basecall quality
                phred_qualscores = [ord(x) - 33 for x in aread.quality]

                # get all the metapeptides we can from this read
                metapeptides_this_read, (n_discarded_tooshort, n_discarded_minqualscore,
                                         n_discarded_stopcodon, n_discarded_longestpep_tooshort,
                                         n_discarded_toofew_trypticsites) = \
                    metapeptides.extract_read_metapeptides(ntseq, phred_qualscores, args.minlength,
                                                         args.minqualscore, args.minorflength, args.minlongesttryppeplen)
                # accounting
                n_frames_discarded_tooshort += n_discarded_tooshort
                n_frames_discarded_minqualscore += n_discarded_minqualscore
                n_frames_discarded_stopcodon += n_discarded_stopcodon
                n_frames_discarded_longestpep_tooshort += n_discarded_longestpep_tooshort
                n_frames_discarded_toofew_trypticsites += n_discarded_toofew_trypticsites
                n_frames_used_in_metapeptides += len(metapeptides_this_read)

                for metapeptide in metapeptides_this_read:
                    if _record_metapeptide(metapeptide_data_map, metapeptide, out_metapeptide_tempfile):
                        n_metapeptides_written_temp += 1

            print("Processed %d records. %d metapeptides this file" %
                  (n_reads_processed, len(metapeptide_data_map)))
            sys.stdout.flush()

    # close the temp file for writing
    out_metapeptide_tempfile.close()
    print("frames discarded because too short: %d" % n_frames_discarded_tooshort)
    print("frames discarded because of a low metapeptide minimum quality score: %d" % n_frames_discarded_minqualscore)
    print("frames discarded because of a stop codon: %d" % n_frames_discarded_stopcodon)
    print("frames discarded because the longest peptide was too short: %d" % n_frames_discarded_longestpep_tooshort)
    print("frames discarded because too few tryptic sites: %d" % n_frames_discarded_toofew_trypticsites)
    if args.minmetagenescore >= 0:
        print("frames discarded because too low MetaGene score: %d" % n_candidates_discarded_minmetagenescore)
    print("frames used in metapeptides: %d" % n_frames_used_in_metapeptides)

    print("Wrote %d metapeptides to temp file." % n_metapeptides_written_temp)
    print("Fixing read counts and min quality scores...")

    # open it right back up again for reading; rows are in the map's insertion order
    tempfile_reader = bgzf.BgzfReader(out_metapeptide_tempfilename)
    tempfile_csvreader = csv.DictReader(tempfile_reader, delimiter='\t')

    out_metapeptide_file = bgzf.BgzfWriter(out_filename, "w")
    print("Building database file %s with count, min quality data..." % out_filename)
    out_metapeptide_file.write("\t".join(metapeptides.METAPEPTIDEDB_COLUMNS) + '\n')

    # update the temp file rows with correct readcount and minqualscore data, write the real file
    n_removed_mincount = 0
    n_written = 0

    for metapeptideseq in metapeptide_data_map:
        tempfile_row = next(tempfile_csvreader)
        # paranoiacally check that we're merging the right rows
        assert tempfile_row['sequence'] == metapeptideseq

        readcount, minqualscore, altseqs = metapeptide_data_map[metapeptideseq]
        if altseqs is None:
            altseqs = []

        # SCAFFOLDING! Sanity check: stored scores presumably all passed the --minqualscore
        # filter, and max() never lowers them, so this should never fire.
        if minqualscore < args.minqualscore:
            sys.exit("minqualscore %f!!!!!!" % minqualscore)

        # check readcount. If it passes, update the row appropriately and write it
        if readcount < args.minreadcount:
            n_removed_mincount += 1
            continue
        tempfile_row['n_reads'] = str(readcount)
        tempfile_row['min_qualscore'] = str(minqualscore)

        # update the nt_sequence column to contain all sequences, including the first
        # occurrence's (which lives only in the temp file row)
        first_nt_seq = tempfile_row['nt_sequence']
        if first_nt_seq not in altseqs:
            altseqs.append(first_nt_seq)
        tempfile_row['nt_sequence'] = ','.join(altseqs)
        out_metapeptide_file.write("\t".join([tempfile_row[field] for field in tempfile_csvreader.fieldnames]) + '\n')
        if args.outfasta:
            args.outfasta.write(">%s\n" % metapeptideseq)
            args.outfasta.write("%s\n" % metapeptideseq)
        n_written += 1
    # release the temp file handle before deleting it
    tempfile_reader.close()
    out_metapeptide_file.close()
    os.remove(out_metapeptide_tempfilename)
    if args.minreadcount > 0 or args.minqualscore > 0:
        print("Removed %d due to low read count." % n_removed_mincount)
    if args.outfasta:
        args.outfasta.close()
        print("Wrote %d metapeptides to fasta file %s." % (n_written, args.outfasta.name))

    print("End time: %s. Elapsed time: %s" % (datetime.now(), datetime.now() - script_start_time))
    print("Done.")

    print("extraction time: %s" % (datetime.now() - script_start_time))

def read_metagene_genes(metagene_file):
    """
    Parse MetaGene Annotator output, yielding one tuple per gene record.

    Lines starting with '#' are comment/header lines:
      - '# gc = ...' and '# self...' lines are skipped;
      - any other '#' line names the read whose gene records follow (the first two
        characters, e.g. '# ', are dropped after stripping whitespace).
    Every other non-blank line must be a tab-separated gene record with exactly 11 fields.
    Blank lines are skipped.

    metagene_file: an open text file (or any iterable of lines) of MetaGene output
    yield: (read name, start position, end position, strand, frame, score), where:
      - both positions are the file's values minus 1 (presumably converting MetaGene's
        1-based coordinates to 0-based; confirm);
      - strand is the raw string from the file;
      - frame is an int and score is a float.
    """
    cur_readname = None
    for line in metagene_file:
        if line.startswith("#"):
            if not line.startswith(("# gc = ", "# self")):
                cur_readname = line.strip()[2:]
            continue
        stripped = line.strip()
        if not stripped:
            # tolerate blank lines (e.g. trailing empty lines) instead of failing the unpack
            continue
        _, startpos_str, endpos_str, strand, frame_str, _, score_str, _, _, _, _ = stripped.split('\t')
        # lazy %-args: these run for every record, so avoid formatting when debug is off
        logger.debug("read: %s: %s", cur_readname, line)
        logger.debug("read: %s %s %s %s %s", cur_readname, startpos_str, endpos_str, strand, frame_str)
        yield (cur_readname, int(startpos_str) - 1, int(endpos_str) - 1, strand, int(frame_str),
               float(score_str))


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
