# Load the YAML or JSON configuration file -------------------------------------

configfile: "config_vc.yaml"

import os

# Derived settings -------------------------------------------------------------

def get_prefixes(filename):
    """Return ``filename`` without its final extension.

    os.path.splitext strips only the last suffix and ignores dots in directory
    names, e.g. "genome.v2.fa" -> "genome.v2" and "../ref/genome.fa" ->
    "../ref/genome". The result is used to name the sequence dictionary that
    sits next to the reference (REF_PREFIX + ".dict").
    """
    return os.path.splitext(filename)[0]


# Path (without extension) of the reference's sequence dictionary.
REF_PREFIX = get_prefixes(config["reference"])
# Prefix shared by the workflow's output and log files.
OUTPUT = config["output"]

# Optional duplicate marking: when enabled, the BAM used downstream carries the
# "_rmdup" tag; otherwise the read-group BAM is used directly.
MARKJOB = config["markduplicate"]
MARKTAG = "_rmdup" if MARKJOB else ""


# Workflow rules ---------------------------------------------------------------

rule all:
    # Default target: the filtered VCF written by rule vcf_filter.
    input:
        filtered_vcf = OUTPUT + "_filter.vcf"

rule add_read_group:
    # Attach read-group information to the input BAM with Picard
    # AddOrReplaceReadGroups.
    input:
        bam = config["input_bam"]
    output:
        rgbam = OUTPUT + "_rg.bam"
    log:
        out = OUTPUT + "_logs/picard_tools/AddOrReplaceReadGroups/stdout.logs",
        err = OUTPUT + "_logs/picard_tools/AddOrReplaceReadGroups/stderr.logs"
    params:
        # Each config["read_group"] entry becomes key immediately followed by
        # value (no separator); entries are joined with spaces.
        # NOTE(review): keys presumably already carry the "=" (e.g. "RGID=").
        r = " ".join(
            "%s%s" % (key, value) for key, value in config["read_group"].items()
        ),
        jar = config["picard_jar"],
        src = config["src_module"],
        picard_module = config["picard_module"]
    run:
        # Only the launcher differs between environments: on the cluster the
        # tool is called by name after loading the module, locally through the
        # Picard jar. The arguments are shared.
        if config["on_cluster"]:
            launcher = (". {params.src}\n"
                        "module load {params.picard_module}\n"
                        "AddOrReplaceReadGroups")
        else:
            launcher = "java -jar {params.jar} AddOrReplaceReadGroups"
        shell(launcher + " I={input.bam} O={output.rgbam} {params.r}"
              " > {log.out} 2> {log.err}")

# Duplicate marking is optional, so the rule is only defined when it is
# enabled. With MARKJOB false, MARKTAG is "" and the output path would equal
# the input path (OUTPUT + "_rg.bam"), making this rule compete with
# add_read_group for the same file.
if MARKJOB:
    rule markDuplicate:
        # Run Picard MarkDuplicates on the read-group BAM, writing the
        # "_rmdup"-tagged BAM and a metrics file.
        input:
            bam = OUTPUT + "_rg.bam"
        output:
            a = "%s_rg%s.bam" % (OUTPUT, MARKTAG)
        log:
            out = OUTPUT + "_logs/picard_tools/MarkDuplicates/stdout.logs",
            err = OUTPUT + "_logs/picard_tools/MarkDuplicates/stderr.logs",
            # Metrics file passed to Picard as M=...
            # NOTE(review): OUTPUT appears twice in this path; if OUTPUT
            # contains directories they are nested under the log dir.
            metrics = OUTPUT + "_logs/picard_tools/MarkDuplicates/" + OUTPUT +
                ".metrics"
        params:
            jar = config["picard_jar"],
            src = config["src_module"],
            picard_module = config["picard_module"]
        run:
            if config["on_cluster"]:
                shell("""
                    . {params.src}
                    module load {params.picard_module}
                    MarkDuplicates I={input.bam} O={output.a} M={log.metrics} \
                    > {log.out} 2> {log.err}
                    """)
            else:
                shell("java -jar {params.jar} MarkDuplicates I={input.bam} \
                    O={output.a} M={log.metrics} > {log.out} 2> {log.err}")

rule samtools_index:
    # Index the read-group BAM (duplicate-marked when MARKJOB is set);
    # samtools writes the index as <bam>.bai.
    input:
        bam = "%s_rg%s.bam" % (OUTPUT, MARKTAG)
    output:
        bai = "%s_rg%s.bam.bai" % (OUTPUT, MARKTAG)
    log:
        OUTPUT + "_logs/samtools_index.logs"
    params:
        src = config["src_module"],
        samtools_module = config["samtools_module"]
    run:
        index_cmd = "samtools index {input.bam} 2> {log}"
        # On the cluster, samtools must first be made available via modules.
        if config["on_cluster"]:
            index_cmd = (". {params.src}\n"
                         "module load {params.samtools_module}\n"
                         + index_cmd)
        shell(index_cmd)

rule create_dict:
    # Build the sequence dictionary (REF_PREFIX + ".dict") for the reference
    # with Picard CreateSequenceDictionary; target_indel takes it as input.
    input:
        ref = config["reference"]
    output:
        REF_PREFIX + ".dict"
    log:
        out = OUTPUT + "_logs/picard_tools/CreateSequenceDictionary/stdout.logs",
        err = OUTPUT + "_logs/picard_tools/CreateSequenceDictionary/stderr.logs"
    params:
        src = config["src_module"],
        picard_module = config["picard_module"],
        jar = config["picard_jar"]
    run:
        if config["on_cluster"]:
            launcher = (". {params.src}\n"
                        "module load {params.picard_module}\n"
                        "CreateSequenceDictionary")
        else:
            launcher = "java -jar {params.jar} CreateSequenceDictionary"
        shell(launcher + " R={input.ref} O={output} > {log.out} 2> {log.err}")

rule target_indel:
    # Find intervals that need local realignment with GATK
    # RealignerTargetCreator; the intervals feed indel_realigner.
    input:
        bam = "%s_rg%s.bam" % (OUTPUT, MARKTAG),
        bai = "%s_rg%s.bam.bai" % (OUTPUT, MARKTAG),
        ref = config["reference"],
        ref_dict = REF_PREFIX + ".dict"
    output:
        intervals = OUTPUT + ".intervals"
    log:
        out = OUTPUT + "_logs/gatk/RealignerTargetCreator/stdout.logs",
        err = OUTPUT + "_logs/gatk/RealignerTargetCreator/stderr.logs"
    params:
        src = config["src_module"],
        gatk_module = config["gatk_module"],
        gatk_jar = config["gatk_jar"]
    run:
        # Cluster: module-provided GenomeAnalysisTK command; local: GATK jar.
        if config["on_cluster"]:
            gatk = (". {params.src}\n"
                    "module load {params.gatk_module}\n"
                    "GenomeAnalysisTK")
        else:
            gatk = "java -jar {params.gatk_jar}"
        shell(gatk + " -T RealignerTargetCreator -R {input.ref} -I {input.bam}"
              " -o {output.intervals} > {log.out} 2> {log.err}")

rule indel_realigner:
    # Realign reads around the intervals from target_indel with GATK
    # IndelRealigner. The BAM index and reference dictionary are declared
    # explicitly (as target_indel does) because GATK reads them, rather than
    # relying only on the indirect ordering through target_indel.
    input:
        bam = "%s_rg%s.bam" % (OUTPUT, MARKTAG),
        bai = "%s_rg%s.bam.bai" % (OUTPUT, MARKTAG),
        ref = config["reference"],
        ref_dict = REF_PREFIX + ".dict",
        interval = OUTPUT + ".intervals"
    output:
        bam = OUTPUT + "_realign.bam"
    log:
        out = OUTPUT + "_logs/gatk/IndelRealigner/stdout.logs",
        err = OUTPUT + "_logs/gatk/IndelRealigner/stderr.logs"
    params:
        src = config["src_module"],
        gatk_module = config["gatk_module"],
        gatk_jar = config["gatk_jar"]
    run:
        # Cluster: module-provided GenomeAnalysisTK command; local: GATK jar.
        if config["on_cluster"]:
            gatk = (". {params.src}\n"
                    "module load {params.gatk_module}\n"
                    "GenomeAnalysisTK")
        else:
            gatk = "java -jar {params.gatk_jar}"
        shell(gatk + " -T IndelRealigner -R {input.ref} -I {input.bam}"
              " -targetIntervals {input.interval} -o {output.bam}"
              " > {log.out} 2> {log.err}")

rule freebayes:
    # Index the realigned BAM, then call variants on it with freebayes.
    input:
        bam = OUTPUT + "_realign.bam",
        ref = config["reference"]
    output:
        vcf = OUTPUT + ".vcf"
    log:
        out = OUTPUT + "_logs/freebayes/stdout.logs",
        err = OUTPUT + "_logs/freebayes/stderr.logs"
    params:
        # Each config["freebayes"] entry is passed as "key value" on the
        # freebayes command line.
        freebayes_params = " ".join(
            "%s %s" % (key, value) for key, value in config["freebayes"].items()
        ),
        src = config["src_module"],
        freebayes_module = config["freebayes_module"],
        samtools_module = config["samtools_module"]
    run:
        # samtools index writes <bam>.bai next to the input BAM (not declared
        # as an output). Its stderr goes to log.err, and freebayes appends
        # (2>>) to the same file so neither tool's diagnostics are lost.
        calling = """
            samtools index {input.bam} 2> {log.err}
            freebayes {params.freebayes_params} -f {input.ref} \
            -b {input.bam} -v {output.vcf} > {log.out} 2>> {log.err}
            """
        # On the cluster both tools must first be made available via modules.
        if config["on_cluster"]:
            calling = """
            . {params.src}
            module load {params.freebayes_module} {params.samtools_module}
            """ + calling
        shell(calling)

rule vcf_filter:
    # Filter the freebayes VCF with sequana's VCF filter, using the criteria
    # in config["vcf_filter"].
    input:
        vcf = OUTPUT + ".vcf"
    output:
        vcf = OUTPUT + "_filter.vcf"
    run:
        # Imported inside run:, so the import only happens when the rule runs.
        from sequana import vcf_filter
        variants = vcf_filter.VCF(input.vcf)
        variants.filter_vcf(config["vcf_filter"], output.vcf)
