#!/usr/bin/env python3
"""Merge a minos VCF with a GVCF at certain positions (driven by the catalogue)."""
import argparse
import gumpy

from vcf_subset import subset_vcf

from pathlib import Path

def fetch_minos_positions(minos_vcf: Path) -> set[int]:
    """Collect every position that the minos VCF already accounts for.

    These are the positions which must not be pulled from the GVCF as well.

    Args:
        minos_vcf (Path): Path to the minos VCF file.
    Returns:
        set[int]: The positions to exclude.
    """
    vcf = gumpy.VCFFile(minos_vcf.as_posix(), ignore_filter=True)
    excluded = set()
    for site, variant_kind in vcf.calls:
        if variant_kind != "indel":
            # Anything other than an indel sits at exactly the position recorded in the VCF
            excluded.add(site)
            continue

        # Indels: an insertion only touches its own position, whereas a
        # deletion also covers every base it removes
        for entry in vcf.calls[(site, variant_kind)]:
            indel_kind, indel_bases = entry["call"]
            if indel_kind == "ins":
                excluded.add(site)
            else:
                deleted_span = range(site, site + len(indel_bases))
                excluded.update(deleted_span)
    return excluded
            


def parse_positions(lines) -> set[int]:
    """Parse one integer position per line, skipping blank lines.

    Args:
        lines (Iterable[str]): Lines (e.g. an open text file) each holding a
            single integer, optionally surrounded by whitespace.
    Returns:
        set[int]: The distinct positions found.
    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    return {int(text) for line in lines if (text := line.strip())}


def rename_dp_to_cov(fields: list[str]) -> list[str]:
    """Rename the DP key of a VCF row's FORMAT column to COV.

    Only the FORMAT column (index 8) is inspected and only an exact ``DP`` key
    is renamed, so keys which merely start with ``DP`` and the other columns
    are left alone. Rows whose FORMAT already has a ``COV`` key, has no ``DP``
    key, or which have no FORMAT column at all are returned unchanged.

    Args:
        fields (list[str]): The tab-split columns of a VCF data row.
    Returns:
        list[str]: The columns with the FORMAT key renamed where applicable.
    """
    if len(fields) <= 8:
        return fields
    keys = fields[8].split(":")
    if "COV" in keys or "DP" not in keys:
        return fields
    keys[keys.index("DP")] = "COV"
    return [*fields[:8], ":".join(keys), *fields[9:]]


def main() -> None:
    """Merge the minos VCF with the GVCF rows at the remaining resistant positions.

    Raises:
        FileNotFoundError: If any of the input files does not exist.
        FileExistsError: If the output file already exists.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--minos_vcf", help="The minos VCF filepath", required=True)
    parser.add_argument("--gvcf", help="The GVCF filepath", required=True)
    parser.add_argument("--resistant-positions", help="Path to list of resistant sites", required=True)
    parser.add_argument("--output", help="The output VCF file path", required=True)
    args = parser.parse_args()

    minos_path = Path(args.minos_vcf)
    gvcf_path = Path(args.gvcf)
    resistant_positions_path = Path(args.resistant_positions)
    output_path = Path(args.output)

    # Sanity checking arguments
    if not minos_path.exists() or not gvcf_path.exists() or not resistant_positions_path.exists():
        raise FileNotFoundError("One or more of the input files does not exist!")
    if output_path.exists():
        raise FileExistsError("Output file already exists!")

    # Read in the resistant positions
    with open(resistant_positions_path) as f:
        resistant_positions = parse_positions(f)

    # Only positions which minos does not already cover are taken from the GVCF
    minos_positions = fetch_minos_positions(minos_path)
    to_fetch = sorted(resistant_positions - minos_positions)
    print(f"Fetching {len(to_fetch)} positions from the GVCF")

    # Pull out these positions from the GVCF
    gvcf_headers, subset = subset_vcf(gvcf_path.as_posix(), to_fetch)
    # NOTE(review): an empty position list is passed for the minos VCF; presumably
    # subset_vcf then returns every row - confirm against vcf_subset
    minos_headers, minos_values = subset_vcf(minos_path.as_posix(), [])

    # Group the header lines by kind so that GVCF-only definitions can be appended
    tags = ("##FORMAT", "##INFO", "##FILTER")
    minos_by_tag = {tag: [header for header in minos_headers if tag in header] for tag in tags}
    gvcf_by_tag = {tag: [header for header in gvcf_headers if tag in header] for tag in tags}
    minos_misc_headers = [
        header
        for header in minos_headers
        if "#CHROM" not in header and not any(tag in header for tag in tags)
    ]
    chrom_line = [header for header in minos_headers if "#CHROM" in header][0]

    with open(output_path, "w") as f:
        for header in minos_misc_headers:
            f.write(header + "\n")

        # Merged FORMAT, then INFO, then FILTER: the minos lines first, followed
        # by any GVCF lines which minos does not already have verbatim
        for tag in tags:
            already_defined = minos_by_tag[tag]
            for header in already_defined:
                f.write(header + "\n")
            for header in gvcf_by_tag[tag]:
                if header not in already_defined:
                    f.write(header + "\n")

        f.write(chrom_line + "\n")

        # Minos rows (remember their positions to avoid writing duplicates below)
        written_positions = set()
        for row in minos_values:
            written_positions.add(int(row.split("\t")[1]))
            f.write(row + "\n")

        # GVCF rows
        for row in subset:
            # Replace DP with COV as GVCF doesn't have a COV-like field
            fields = rename_dp_to_cov(row.split("\t"))
            if int(fields[1]) in written_positions:
                # We should already be filtering out positions, but in cases of dels,
                # the start can slip through. Catch duplicates in these cases
                continue

            # GVCF doesn't explicitly call filter passes, so ensure the calls are picked up
            if fields[6] == ".":
                fields[6] = "PASS"
            f.write("\t".join(fields) + "\n")


if __name__ == "__main__":
    main()
        



    


