#!/usr/bin/env python3

import os
import argparse

from fqutil import Fastq, Read, prefix_extension

def main():
    """Re-pair two discordant FASTQ files given on the command line.

    The second file is loaded into an in-memory index via ``Fastq.to_dict()``
    (read id -> ``(seq, qual)``), then the first file is streamed and every
    read whose id is present in the index is written, together with its mate,
    to the ``*_common`` output files.  With ``--keep-unique`` the reads that
    have no mate are written to ``*_unique`` files instead of being dropped.

    Output files are named after the input files (with ``_common`` /
    ``_unique`` inserted by ``prefix_extension``) and are placed in
    ``--output-dir`` or, by default, in the directory of the first input file.

    NOTE(review): the whole second file is held in memory while the first is
    processed; presumably acceptable for the intended input sizes.
    """
    parser = argparse.ArgumentParser(
            description='Re-pair reads from discordant FASTQ files (for instance, after filtering reads by quality).')
    parser.add_argument('-u', '--keep-unique', action='store_true',
            help='Keep reads without a corresponding match. Unique reads are output to a separate file (filename_unique.fastq).')
    parser.add_argument('-o', '--output-dir', default=None, type=str,
            help='Specify a custom output directory (defaults to directory of input files).')
    parser.add_argument('fastq1', type=str,
            help='First FASTQ file to be re-paired. May be gzipped.')
    parser.add_argument('fastq2', type=str,
            help='Second FASTQ file to be re-paired. May be gzipped.')
    args = parser.parse_args()

    fq1 = args.fastq1
    fq2 = args.fastq2

    # Resolve the output directory: default to where the first input lives,
    # otherwise create the requested directory.  exist_ok avoids the
    # check-then-create race of a separate os.path.exists() test.
    output_dir = args.output_dir
    if output_dir is None:
        output_dir = os.path.dirname(fq1)
    else:
        os.makedirs(output_dir, exist_ok=True)

    # Index the second file for random access by read id.
    with Fastq(fq2) as fastq:
        idx = fastq.to_dict()
        fq2_total = fastq.readno

    out1 = os.path.join(output_dir, os.path.basename(fq1))
    out2 = os.path.join(output_dir, os.path.basename(fq2))

    # The unique-reads file for fastq1 is optional, so it cannot simply be
    # part of the `with` statement below; the try/finally guarantees it is
    # closed even if processing fails half-way.
    fq1_unique = None
    if args.keep_unique:
        fq1_unique = Fastq(prefix_extension(out1, '_unique'), 'w')

    try:
        # Stream the first file and look each read up in the index.
        with Fastq(fq1) as fastq, \
                Fastq(prefix_extension(out1, '_common'), 'w') as fq1_common, \
                Fastq(prefix_extension(out2, '_common'), 'w') as fq2_common:

            kept = 0
            while True:
                read = fastq.read()
                if read is None:
                    break

                read_id = read.unique_id()
                if read_id in idx:
                    # Match found: write both mates and drop the entry from
                    # the index so that only unmatched reads remain in it.
                    read.write(fq1_common)
                    seq, qual = idx.pop(read_id)
                    Read(read_id + '\n', seq, qual).write(fq2_common)

                    kept += 1
                elif fq1_unique is not None:
                    read.write(fq1_unique)

            fq1_total = fastq.readno
    finally:
        if fq1_unique is not None:
            fq1_unique.close()

    if args.keep_unique:
        # Everything still left in the index had no mate in fastq1.
        with Fastq(prefix_extension(out2, '_unique'), 'w') as fq2_unique:
            for read_id, (seq, qual) in idx.items():
                fq2_unique.writelines([read_id, '\n', seq, '+\n', qual])

    print_stats(fq1, fq1_total, kept)
    print_stats(fq2, fq2_total, kept)
    print('{} reads paired.'.format(kept))


def print_stats(filename, total, kept):
    """Print how many reads of one input file were left without a mate.

    Args:
        filename: Name of the FASTQ file the numbers refer to.
        total: Total number of reads counted in that file.
        kept: Number of reads from that file that were paired.
    """
    n_uniq = total - kept
    # Guard on `total` (the divisor), not on `kept`: the percentage must be
    # computed whenever there are reads at all, and an empty file must not
    # trigger a division by zero.
    if total > 0:
        p_uniq = 100.0 - (kept / total * 100.0)
    else:
        p_uniq = 0.0

    print('{}: {} unique reads identified out of a total of {} ({:.1f}% unique)'
        .format(filename, n_uniq, total, p_uniq))


# Run the command-line entry point only when executed as a script, so that
# importing this module has no side effects.
if __name__ == '__main__':
    main()
