#!/usr/bin/env python
"""
Merge multiple metapeptide database files into a single metapeptide database, summing read counts and keeping the minimum quality score for metapeptides that occur in more than one input.
"""

import argparse
import logging
from datetime import datetime
from sixgill import metapeptides
import sys
from Bio import bgzf
from collections import OrderedDict
import csv
import os

# Module metadata.
__author__ = "Damon May"
__copyright__ = "Copyright (c) 2016 Damon May"
__license__ = "Apache 2.0"
__version__ = "0.1"

# Module-level logger; main() sets its level to DEBUG when --debug is passed.
logger = logging.getLogger(__name__)


def declare_gather_args():
    """
    Build this script's command-line parser and parse sys.argv with it.

    Only argparse's own checks are applied (for example, every input file
    must be openable for reading); nothing else is validated here.
    return: an argparse.Namespace with attributes
            metapeptidedbfiles (list of open input files),
            out (open output file) and debug (bool)
    """
    arg_parser = argparse.ArgumentParser(description=__doc__)

    # one or more input databases, opened by argparse to validate them
    arg_parser.add_argument(
        'metapeptidedbfiles',
        nargs='+',
        type=argparse.FileType('r'),
        help='input metapeptide database files',
    )
    arg_parser.add_argument(
        '--out',
        required=True,
        type=argparse.FileType('w'),
        help='output file',
    )
    arg_parser.add_argument(
        '--debug',
        action="store_true",
        help='Enable debug logging',
    )

    parsed_args = arg_parser.parse_args()
    return parsed_args


def _merge_metapeptide_files(metapeptide_db_files, out_tempfile):
    """
    Stream every input metapeptide database into out_tempfile, deduplicating by sequence.

    The first occurrence of each metapeptide sequence is written to out_tempfile
    verbatim (via make_output_line()). Later occurrences only update in-memory
    bookkeeping, which the final output pass applies to the temp-file rows.

    metapeptide_db_files: file handles opened by argparse; only their .name is used.
                          Each handle is closed and the file re-opened as a BGZF reader.
    out_tempfile: an open bgzf.BgzfWriter whose header line has already been written
    return: (seq_to_data, n_lines_read). seq_to_data is an OrderedDict mapping
            metapeptide sequence -> [summed n_reads, minimum min_qualscore,
            list of distinct nt sequences], in the order the sequences were
            written to out_tempfile. n_lines_read is the total number of
            metapeptide lines read across all files.
    """
    seq_to_data = OrderedDict()
    n_lines_read = 0
    for metapeptide_file in metapeptide_db_files:
        metapeptide_filename = metapeptide_file.name
        # argparse opened this in text mode just to validate it; release that handle
        if metapeptide_file is not sys.stdin:
            metapeptide_file.close()
        print("Reading file %s..." % metapeptide_filename)
        n_lines_this_file = 0
        with bgzf.BgzfReader(metapeptide_filename) as bgzf_reader:
            for metapeptide in metapeptides.read_metapeptides(bgzf_reader):
                if n_lines_read % 5000000 == 0:
                    print("    Read %d lines (all files). Wrote %d to temp file...." %
                          (n_lines_read, len(seq_to_data)))
                    sys.stdout.flush()
                    out_tempfile.flush()
                n_lines_this_file += 1
                n_lines_read += 1

                # nt_seqs may be None for some metapeptides; treat that as "no sequences"
                new_nt_seqs = metapeptide.nt_seqs or []
                data = seq_to_data.get(metapeptide.sequence)
                if data is None:
                    # Got a new one: write its row now, fix up counts in the final pass
                    out_tempfile.write(metapeptide.make_output_line() + '\n')
                    # copy, so merging later duplicates never mutates this metapeptide's list
                    seq_to_data[metapeptide.sequence] = [metapeptide.n_reads,
                                                         metapeptide.min_qualscore,
                                                         list(new_nt_seqs)]
                else:
                    # not a new one. Update our recordkeeping
                    data[0] += metapeptide.n_reads
                    data[1] = min(data[1], metapeptide.min_qualscore)
                    known_nt_seqs = data[2]
                    for nt_seq in new_nt_seqs:
                        if nt_seq not in known_nt_seqs:
                            known_nt_seqs.append(nt_seq)
        print("Read %d lines this file" % n_lines_this_file)
    return seq_to_data, n_lines_read


def _write_merged_output(tempfilename, out_filename, seq_to_data):
    """
    Rewrite the temp file into the final output, filling in merged per-metapeptide data.

    Temp-file rows are matched one-to-one, in order, with the entries of
    seq_to_data. For each row:
      - n_reads and min_qualscore are replaced with the merged values;
      - nt_sequence becomes the comma-separated union of the row's own
        sequences and every nucleotide sequence seen for that metapeptide.

    tempfilename: BGZF temp file written by _merge_metapeptide_files
    out_filename: path of the BGZF output file to create (overwritten)
    seq_to_data: the OrderedDict returned by _merge_metapeptide_files
    return: the number of metapeptide rows written
    raises ValueError: if the temp file's rows do not line up with seq_to_data
    """
    n_written = 0
    with bgzf.BgzfReader(tempfilename) as temp_reader, \
            bgzf.BgzfWriter(out_filename, "w") as out_writer:
        tempfile_csvreader = csv.DictReader(temp_reader, delimiter='\t')
        out_writer.write("\t".join(metapeptides.METAPEPTIDEDB_COLUMNS) + '\n')
        for metapeptideseq, (readcount, minqualscore, altseqs) in seq_to_data.items():
            try:
                tempfile_row = next(tempfile_csvreader)
            except StopIteration:
                raise ValueError("Temp file %s ended before metapeptide %s" %
                                 (tempfilename, metapeptideseq))
            # paranoid check that we're merging the right rows; an explicit raise
            # (not assert) so it still runs under python -O
            if tempfile_row['sequence'] != metapeptideseq:
                raise ValueError("Temp file row sequence %s does not match expected %s" %
                                 (tempfile_row['sequence'], metapeptideseq))

            tempfile_row['n_reads'] = str(readcount)
            tempfile_row['min_qualscore'] = str(minqualscore)
            # update the nt_sequence column to contain all sequences; drop empty
            # entries so an empty column doesn't produce a leading comma
            nt_seqs = [seq for seq in tempfile_row['nt_sequence'].split(',') if seq]
            for alt_seq in altseqs:
                if alt_seq not in nt_seqs:
                    nt_seqs.append(alt_seq)
            tempfile_row['nt_sequence'] = ','.join(nt_seqs)

            out_writer.write("\t".join([tempfile_row[field]
                                        for field in tempfile_csvreader.fieldnames]) + '\n')
            n_written += 1
    return n_written


def main():
    """
    Merge the metapeptide databases named on the command line into one BGZF output file.

    Works in two passes, keeping only per-sequence bookkeeping (not full rows)
    in memory:
      1. Stream every input into a temp file "<out>.temp.gz" that holds the
         first row seen for each distinct metapeptide sequence.
      2. Rewrite that temp file into --out with summed read counts, minimum
         quality scores and the union of nucleotide sequences, then delete
         the temp file.
    """
    args = declare_gather_args()
    # logging
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s: %(message)s")
    if args.debug:
        logger.setLevel(logging.DEBUG)
        # any module-specific debugging goes below

    script_start_time = datetime.now()
    logger.debug("Start time: %s", script_start_time)

    out_filename = args.out.name
    # argparse opened (and truncated) the output in text mode; close that handle
    # so the BGZF writer below is the only one writing this file
    if args.out is not sys.stdout:
        args.out.close()

    out_tempfilename = out_filename + '.temp.gz'
    with bgzf.BgzfWriter(out_tempfilename, "w") as out_tempfile:
        print("Created tempfile %s" % out_tempfilename)
        out_tempfile.write('\t'.join(metapeptides.METAPEPTIDEDB_COLUMNS) + "\n")
        seq_to_data, n_lines_read = _merge_metapeptide_files(args.metapeptidedbfiles,
                                                             out_tempfile)

    print("Done reading input databases. Read %d metapeptide lines, %d unique metapeptides" %
          (n_lines_read, len(seq_to_data)))
    print("Wrote temp file %s" % out_tempfilename)

    print("Building output file %s with count, min quality data..." % out_filename)
    n_written = _write_merged_output(out_tempfilename, out_filename, seq_to_data)

    # the temp reader is closed by now, so removal also works on Windows
    os.remove(out_tempfilename)
    print("Deleted temp file")

    print("Wrote %d entries to metapeptide database file %s." % (n_written, out_filename))

    print("Done.")


# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()
