#!/usr/bin/env python3

__version__ = '0.4.3'
__author__ = 'David Heller'

import sys
import argparse
import os
import re
import pickle
import gzip
import logging
import configparser
import pysam

from time import strftime, localtime

from svim.SVIM_alignment import run_alignment
from svim.SVIM_COLLECT import analyze_alignment_file_coordsorted, analyze_alignment_file_querysorted
from svim.SVIM_CLUSTER import cluster_sv_signatures, write_signature_clusters_bed, write_signature_clusters_vcf, plot_histograms
from svim.SVIM_COMBINE import combine_clusters


def _add_common_arguments(subparser, alignment_options=False):
    """Add the COLLECT, CLUSTER and COMBINE option groups shared by both modes.

    :param subparser: the 'reads' or 'alignment' sub-command parser to extend
    :param alignment_options: also add the read-alignment options (--cores,
        --aligner, --nanopore) used only when SVIM aligns the reads itself
    """
    group_collect = subparser.add_argument_group('COLLECT')
    group_collect.add_argument('--min_mapq', type=int, default=20, help='Minimum mapping quality of reads to consider')
    group_collect.add_argument('--min_sv_size', type=int, default=40, help='Minimum SV size to detect')
    group_collect.add_argument('--max_sv_size', type=int, default=100000, help='Maximum SV size to detect')
    group_collect.add_argument('--skip_indel', action='store_true', help='disable signature collection from within read alignments')
    group_collect.add_argument('--skip_segment', action='store_true', help='disable signature collection from between read alignments')
    if alignment_options:
        # Inserted at this position to keep the original --help ordering.
        group_collect.add_argument('--cores', type=int, default=1, help='CPU cores to use for alignment with ngmlr')
        group_collect.add_argument('--aligner', type=str, default="ngmlr", choices=["ngmlr", "minimap2"], help='tool for read alignment: ngmlr or minimap2 (default: ngmlr)')
        group_collect.add_argument('--nanopore', action='store_true', help='use Nanopore settings for read alignment (default: off)')
    group_collect.add_argument('--segment_gap_tolerance', type=int, default=10, help='Maximum tolerated gap between adjacent alignment segments')
    group_collect.add_argument('--segment_overlap_tolerance', type=int, default=5, help='Maximum tolerated overlap between adjacent alignment segments')

    group_cluster = subparser.add_argument_group('CLUSTER')
    group_cluster.add_argument('--partition_max_distance', type=int, default=5000, help='Maximum distance in bp between SVs in a partition')
    group_cluster.add_argument('--distance_normalizer', type=int, default=900, help='Distance normalizer used for span-position distance')
    group_cluster.add_argument('--cluster_max_distance', type=float, default=0.7, help='Maximum span-position distance between SVs in a cluster')

    group_combine = subparser.add_argument_group('COMBINE')
    group_combine.add_argument('--del_ins_dup_max_distance', type=float, default=1.0, help='Maximum span-position distance between the origin of an insertion and a deletion to be flagged as a potential cut&paste insertion')
    group_combine.add_argument('--trans_destination_partition_max_distance', type=int, default=1000, help='Maximum distance in bp between translocation breakpoint destinations in a partition')
    group_combine.add_argument('--trans_partition_max_distance', type=int, default=200, help='Maximum distance in bp between translocation breakpoints in a partition')
    group_combine.add_argument('--trans_sv_max_distance', type=int, default=500, help='Maximum distance in bp between a translocation breakpoint and an SV signature to be combined')
    group_combine.add_argument('--sample', type=str, default="Sample", help='Sample ID to include in output vcf (default: Sample)')


def parse_arguments():
    """Build the SVIM command-line parser and parse ``sys.argv``.

    Two sub-commands are defined: 'reads' (align raw reads to a reference
    first) and 'alignment' (use an existing SAM/BAM file). Both share the
    COLLECT/CLUSTER/COMBINE options; only 'reads' has the aligner options.

    :returns: the parsed ``argparse.Namespace``. Its ``sub`` attribute holds
        the chosen sub-command name, or None when no sub-command was given.
    """
    parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                     description="""SVIM (pronounced SWIM) is a structural variant caller for long reads. 
It discriminates five different variant classes: deletions, tandem and interspersed duplications, 
inversions and novel element insertions. SVIM is unique in its capability of extracting both the genomic origin and 
destination of duplications.

SVIM consists of three major steps:
- COLLECT detects signatures for SVs in long read alignments
- CLUSTER merges signatures that come from the same SV
- COMBINE combines clusters from different genomic regions and classifies them into distinct SV types

SVIM can process two types of input. Firstly, it can detect SVs from raw reads by aligning them to a given reference genome first ("SVIM.py reads [options] working_dir reads genome").
Alternatively, it can detect SVs from existing reads alignments in SAM/BAM format ("SVIM.py alignment [options] working_dir bam_file").
""")
    subparsers = parser.add_subparsers(help='modes', dest='sub')
    parser.add_argument('--version', '-v', action='version', version='%(prog)s {version}'.format(version=__version__))

    parser_fasta = subparsers.add_parser('reads', help='Detect SVs from raw reads. Align reads to given reference genome first.')
    # Resolved to an absolute path, as in 'alignment' mode, so that both modes
    # hand the same kind of working-directory path to the later steps.
    parser_fasta.add_argument('working_dir', type=os.path.abspath, help='working directory')
    parser_fasta.add_argument('reads', type=str, help='Read file (FASTA, FASTQ, gzipped FASTA and FASTQ)')
    parser_fasta.add_argument('genome', type=str, help='Reference genome file (FASTA)')
    _add_common_arguments(parser_fasta, alignment_options=True)

    parser_bam = subparsers.add_parser('alignment', help='Detect SVs from an existing alignment')
    parser_bam.add_argument('working_dir', type=os.path.abspath, help='working directory')
    parser_bam.add_argument('bam_file', type=argparse.FileType('r'), help='SAM/BAM file with aligned long reads (should be queryname-sorted with samtools sort -n)')
    _add_common_arguments(parser_bam)

    return parser.parse_args()


def guess_file_type(reads_path):
    """Infer the format of a reads file from its file ending.

    Matching is case-insensitive, so e.g. ``.FASTA`` and ``.fa.GZ`` are
    recognized just like ``.FA`` always was.

    :param reads_path: path of the reads file
    :returns: one of "fasta", "fastq", "fasta_gzip", "fastq_gzip", "list"
        (a text file listing further reads files) or "unknown" (after
        logging an error)
    """
    # (accepted endings, returned type, description used in the log message).
    # No ending in one rule is a suffix of an ending in another rule
    # (".fa" does not match "x.fa.gz"), so the rule order does not matter.
    rules = (
        ((".fa", ".fasta"), "fasta", "FASTA format"),
        ((".fq", ".fastq"), "fastq", "FASTQ format"),
        ((".fa.gz", ".fasta.gz", ".fa.gzip", ".fasta.gzip"), "fasta_gzip", "gzipped FASTA format"),
        ((".fq.gz", ".fastq.gz", ".fq.gzip", ".fastq.gzip"), "fastq_gzip", "gzipped FASTQ format"),
        ((".fa.fn", ".fq.fn"), "list", "file list format"),
    )
    lowered_path = reads_path.lower()
    for endings, file_type, description in rules:
        if lowered_path.endswith(endings):
            logging.info("Recognized reads file as {0}.".format(description))
            return file_type
    logging.error("Unknown file ending of file {0}. See github.com/eldariont/svim/wiki/ for supported file endings. Exiting.".format(reads_path))
    return "unknown"


def read_file_list(path):
    """Yield the file paths listed in the text file *path*, one per line.

    Surrounding whitespace is stripped and blank lines are skipped; an empty
    entry would otherwise be passed on as a file name with no valid ending.
    The list file is closed when the generator is exhausted, and also when it
    is closed or garbage-collected before reaching the end (e.g. when the
    caller stops early).
    """
    with open(path, "r") as file_list:
        for line in file_list:
            entry = line.strip()
            if entry:
                yield entry


def _setup_logging(working_dir):
    """Send INFO-and-above log records to the console and to a new log file
    ``SVIM_<yymmdd_HHMMSS>.log`` inside *working_dir*.

    Both handlers are attached to the root logger and share one format.
    """
    log_formatter = logging.Formatter("%(asctime)s [%(levelname)-7.7s]  %(message)s")
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    log_name = "SVIM_{0}.log".format(strftime("%y%m%d_%H%M%S", localtime()))
    file_handler = logging.FileHandler(os.path.join(working_dir, log_name), mode="w")
    file_handler.setFormatter(log_formatter)
    root_logger.addHandler(file_handler)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(log_formatter)
    root_logger.addHandler(console_handler)


def _align_and_collect(options, reads_path, reads_type):
    """Align one reads file to the reference genome and collect SV signatures.

    The alignment produced by run_alignment is analyzed as queryname-sorted.
    Returns a tuple (signatures, reference names, reference lengths); the
    alignment file is closed before returning.
    """
    bam_path = run_alignment(options.working_dir, options.genome, reads_path, reads_type,
                             options.cores, options.aligner, options.nanopore)
    with pysam.AlignmentFile(bam_path) as aln_file:
        signatures = analyze_alignment_file_querysorted(aln_file, options)
        return signatures, aln_file.references, aln_file.lengths


def _collect_from_reads(options):
    """Run the COLLECT step in 'reads' mode.

    options.reads is either a single reads file or a file list naming several
    reads files. Each file is aligned and the signatures of all files are
    pooled. Returns (signatures, reference names, reference lengths), or None
    after logging an error if an input cannot be used.
    """
    logging.info("MODE: reads")
    logging.info("INPUT: {0}".format(os.path.abspath(options.reads)))
    logging.info("GENOME: {0}".format(os.path.abspath(options.genome)))
    reads_type = guess_file_type(options.reads)
    if reads_type == "unknown":
        return None
    if reads_type != "list":
        # Single read file
        return _align_and_collect(options, options.reads, reads_type)

    # List of read files. Every file is aligned against the same genome, so
    # the reference names/lengths of the last alignment describe all of them.
    sv_signatures = []
    references = lengths = None
    for index, file_path in enumerate(read_file_list(options.reads)):
        logging.info("Starting processing of file {0} from the list..".format(index))
        file_type = guess_file_type(file_path)
        if file_type == "unknown":
            return None
        if file_type == "list":
            logging.error("Nested file lists are not supported (found {0} in {1}). Exiting.".format(file_path, options.reads))
            return None
        signatures, references, lengths = _align_and_collect(options, file_path, file_type)
        sv_signatures.extend(signatures)
    if references is None:
        # Nothing was aligned, so there is no reference information for COMBINE.
        logging.error("The file list {0} does not name any read files. Exiting.".format(options.reads))
        return None
    return sv_signatures, references, lengths


def _collect_from_alignment(options):
    """Run the COLLECT step in 'alignment' mode.

    Picks the coordinate- or queryname-sorted collector based on the sort
    order recorded in the SAM/BAM header (HD/SO). Returns (signatures,
    reference names, reference lengths), or None after logging an error if
    the sort order is missing or unsupported.
    """
    logging.info("MODE: alignment")
    bam_path = options.bam_file.name
    # argparse.FileType opened this handle, but only its name is used;
    # pysam opens the file itself.
    options.bam_file.close()
    logging.info("INPUT: {0}".format(os.path.abspath(bam_path)))
    with pysam.AlignmentFile(bam_path) as aln_file:
        # Look up the sort order on its own so that a KeyError raised while
        # collecting signatures is not misreported as a missing header field.
        try:
            sort_order = aln_file.header["HD"]["SO"]
        except KeyError:
            logging.error("Is the given input BAM file sorted? It does not contain a sorting order in its header line.")
            return None
        if sort_order == "coordinate":
            logging.warning("Input BAM file is coordinate-sorted. SVIM can process it but will be less accurate than for queryname-sorted input. It is highly recommended to sort the BAM file by queryname using samtools sort -n.")
            sv_signatures = analyze_alignment_file_coordsorted(aln_file, options)
        elif sort_order == "queryname":
            sv_signatures = analyze_alignment_file_querysorted(aln_file, options)
        else:
            logging.error("Input BAM file needs to be queryname-sorted (highly recommended) or coordinate-sorted. The given file, however, is unsorted according to its header line.")
            return None
        return sv_signatures, aln_file.references, aln_file.lengths


def main():
    """Run SVIM's three steps (COLLECT, CLUSTER, COMBINE) as configured on the command line.

    :returns: 1 if no mode was chosen or the input could not be processed,
        so that ``sys.exit(main())`` reports the failure; None (exit status 0)
        on success.
    """
    # Fetch command-line options
    options = parse_arguments()
    # Not a command-line option; set here for the later steps that receive `options`.
    options.distance_metric = "sl"

    if not options.sub:
        print("Please choose one of the two modes ('reads' or 'alignment'). See --help for more information.", file=sys.stderr)
        return 1

    # Create the working dir if it does not exist; the log file is written into it.
    os.makedirs(options.working_dir, exist_ok=True)
    _setup_logging(options.working_dir)

    logging.info("****************** Start SVIM, version {0} ******************".format(__version__))
    logging.info("CMD: python3 {0}".format(" ".join(sys.argv)))
    logging.info("WORKING DIR: {0}".format(os.path.abspath(options.working_dir)))
    for arg in vars(options):
        logging.info("PARAMETER: {0}, VALUE: {1}".format(arg, getattr(options, arg)))

    logging.info("****************** STEP 1: COLLECT ******************")
    if options.sub == 'reads':
        collected = _collect_from_reads(options)
    else:
        collected = _collect_from_alignment(options)
    if collected is None:
        return 1
    sv_signatures, references, lengths = collected

    # (signature type tag, description used in the summary log line)
    signature_kinds = (
        ('del', "deleted regions"),
        ('ins', "inserted regions"),
        ('inv', "inverted regions"),
        ('dup', "tandem duplicated regions"),
        ('tra', "translocation breakpoints"),
        ('ins_dup', "inserted regions with detected region of origin"),
    )
    for sv_type, description in signature_kinds:
        count = sum(1 for signature in sv_signatures if signature.type == sv_type)
        logging.info("Found {0} signatures for {1}.".format(count, description))

    logging.info("****************** STEP 2: CLUSTER ******************")
    signature_clusters = cluster_sv_signatures(sv_signatures, options)

    # Write SV signature clusters
    logging.info("Finished clustering. Writing signature clusters..")
    write_signature_clusters_bed(options.working_dir, signature_clusters)
    write_signature_clusters_vcf(options.working_dir, signature_clusters, __version__)

    # Create result plots
    plot_histograms(options.working_dir, signature_clusters)

    logging.info("****************** STEP 3: COMBINE ******************")
    combine_clusters(signature_clusters, options.working_dir, options, __version__, references, lengths, options.sample)

if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())