from collections import defaultdict
from yaml import load
import os

from snakemake.utils import Paramspace
import pandas as pd
import pathlib

from snakemake.io import chain, expand, regex
import sys
import os
import re
import collections
import random
import datetime
from itertools import repeat


# Shared helpers live in a separate rule file.
# NOTE(review): cellsnake_glob_wildcards, used further down, is presumably defined there — confirm.
include: "rules/extra_functions.smk"
# Workflow mode. Values tested later in this file: "minimal", "standard", "advanced",
# "clustree"/"clusteringTree" and "integration"/"integrate".
option=config.get("option","standard") # defaults to the "standard" analysis
# Run identifier: taken from the config, otherwise 3 random letters followed by 5 random digits.
runid=config.get("runid","".join(random.choices("abcdefghisz",k=3) + random.choices("123456789",k=5)))
# Main log file name: cellsnake_<runid>_<yymmdd_HHMMSS>_<option>_log
logname = "_".join(["cellsnake",runid, datetime.datetime.now().strftime("%y%m%d_%H%M%S"),option,"log"])

cellsnake_path=config.get('cellsnake_path','') # path prefix set when called by cellsnake; empty otherwise
datafolder=config.get("datafolder","data") # this can be a directory with samples or a single file to process
analyses_folder=config.get("analyses_folder","analyses")
results_folder=config.get("results_folder","results")
is_integrated_sample=config.get("is_integrated_sample",False)
# Optional single gene from the config, wrapped in a list so more genes can be appended later.
gene_to_plot=[] if config.get("gene_to_plot",None) is None else [config.get("gene_to_plot")]


#configfile: "config.yaml"

# Sample discovery.
# - If `datafolder` is a directory, sample names are collected from the supported
#   matrix layouts (gzipped and plain matrix.mtx, plus filtered_feature_bc_matrix.h5).
# - If `datafolder` is a single file, its stem becomes the only sample name; an
#   .rds file is only accepted when `is_integrated_sample` is set, in which case the
#   analyses/results folders are redirected to their "_integrated" variants.
# Discovery is best-effort: a failure leaves `files` as it was and is reported
# instead of aborting the workflow parse.
files=[]
try:

    if os.path.isdir(datafolder):
        cellranger_new = cellsnake_glob_wildcards(datafolder +  "/{sample}/raw_feature_bc_matrix/matrix.mtx.gz").sample + cellsnake_glob_wildcards(datafolder + "/{sample}/outs/filtered_feature_bc_matrix/matrix.mtx.gz").sample + cellsnake_glob_wildcards(datafolder +  "/{sample}/matrix.mtx.gz").sample
        cellranger_old = cellsnake_glob_wildcards(datafolder +  "/{sample}/raw_feature_bc_matrix/matrix.mtx").sample + cellsnake_glob_wildcards(datafolder + "/{sample}/outs/filtered_feature_bc_matrix/matrix.mtx").sample + cellsnake_glob_wildcards(datafolder +  "/{sample}/matrix.mtx").sample
        h5files, = cellsnake_glob_wildcards(datafolder + "/{sample}/filtered_feature_bc_matrix.h5")
        # Do not capture subdirectories, and keep each sample only once (first
        # occurrence wins) even when several layouts match the same sample.
        files=list(dict.fromkeys(i for i in cellranger_new + cellranger_old + h5files if "/" not in i))
        print(files)

    if os.path.isfile(datafolder):
        file_extension = pathlib.Path(datafolder)
        if (file_extension.suffix).lower() not in [".rds"]:
            files = [file_extension.stem]
            print(files)
        elif is_integrated_sample:
            analyses_folder="analyses_integrated"
            results_folder="results_integrated"
            files = [file_extension.stem]

except Exception as err:
    # Only ordinary errors are handled here; KeyboardInterrupt/SystemExit propagate.
    print("No samples detected")
    print("Sample detection failed: " + repr(err), file=sys.stderr)



#basic parameters (each value can be overridden through the Snakemake config)
min_cells=config.get("min_cells",3)
min_features=config.get("min_features",200) #nFeature_RNA
max_features=config.get("max_features","Inf") #nFeature_RNA
max_molecules=config.get("max_molecules","Inf") #nCount_RNA
min_molecules=config.get("min_molecules",0) #nCount_RNA
percent_mt=str(config.get("percent_mt",10)) #stored as a string; if not automatic, this will be used for all samples
percent_rp=config.get("percent_rp",0) #by default, no filtering on ribosomal gene percentage
highly_variable_features=config.get("highly_variable_features",2000)
variable_selection_method=config.get("variable_selection_method","vst")
#the doublet filter flag is only emitted for non-integrated samples
doublet_filter= "--doublet.filter" if config.get("doublet_filter",True) in [True,"TRUE","True","T"] and is_integrated_sample is False else ""
metadata_file=config.get("metadata_file","metadata.csv")

#parameter grid search: when True, parameter combinations are read from a params file
grid_search=config.get("grid_search",False) #requires params file in tsv format


#automatic cluster/resolution detection (currently disabled)
#detect_resolution=config.get("detect_resolution",True)

#clustering and normalization parameters
normalization_method=config.get("normalization_method","LogNormalize")
scale_factor=config.get("scale_factor",10000)
resolution=config.get("resolution","0.8" if is_integrated_sample is False else "0.3") #a double or "auto"; default is "0.8", or "0.3" for an integrated sample

#dimension reduction options: by default both UMAP and TSNE will be plotted with predicted clusters
umap_plot="--umap" if config.get("umap_plot",True) in [True,"TRUE","True","T"] else ""
tsne_plot="--tsne" if config.get("tsne_plot",True) in [True,"TRUE","True","T"] else ""

#Each marker plot contains a DimPlot, by default only UMAP plots will be created (to save space and time), this behavior can be changed.
umap_markers_plot="--umap" if config.get("umap_markers_plot",True) in [True,"TRUE","True","T"] else ""
tsne_markers_plot="--tsne" if config.get("tsne_markers_plot",False) in [True,"TRUE","True","T"] else ""

#Differential expression parameters
logfc_threshold=config.get("logfc_threshold",0.25)
test_use=config.get("test_use","wilcox")
marker_plots_per_cluster_n=config.get("marker_plots_per_cluster_n",20) #number of marker plots per cluster, 20 by default


#identities to analyse, kept as a list so that more identities can be appended later
identity_to_analysis=[config.get("identity_to_analysis","seurat_clusters")]

#optional table of genes to plot; its first column is read further below
selected_gene_file=config.get("selected_gene_file","markers.tsv")

#enrichment parameters
algorithm=config.get("algorithm","weight01")
statistics=config.get("statistics","ks")
mapping=config.get("mapping","org.Hs.eg.db")
organism=config.get("organism","hsa")

#GSEA: falls back to the gene set file bundled under cellsnake_path when none is configured
gsea_file=cellsnake_path + "workflow/bundle/c2.cgp.v2022.1.Hs.symbols.gmt" if config.get("gsea_file") is None else config.get("gsea_file")
gsea_group=config.get("gsea_group","seurat_clusters")


#integration and cell type annotation
integration_id=config.get("integration_id","integrated")
celltypist_model=config.get("celltypist_model","Immune_All_Low.pkl")
singler_ref=config.get("singler_ref","HumanPrimaryCellAtlasData")


#cellchat
species=config.get("species","human") #cellchat species, only mouse or human

#Kraken DB: microbiome targets are only requested when kraken_db_folder is set
kraken_db_folder=config.get("kraken_db_folder",None)
taxa=config.get("taxa","genus")
microbiome_min_cells=config.get("microbiome_min_cells",1)
microbiome_min_features=config.get("microbiome_min_features",3)
confidence=config.get("confidence",0.05)
min_hit_groups=config.get("min_hit_groups",4)

#initialization of paramspace object using defaults
def initialization_of_paramspace(tsv_file,dictionary):
    """Build the parameter table that backs the Paramspace object.

    The table is read from `tsv_file` only when that file exists and the
    module-level `grid_search` switch is exactly True; otherwise it is built
    from `dictionary` (column name -> list of values).

    Returns a pandas DataFrame.
    """
    # The file check comes first so the grid_search switch is only consulted
    # when a parameter file is actually present.
    use_params_file = os.path.isfile(tsv_file) and grid_search is True
    if not use_params_file:
        return pd.DataFrame(dictionary) # defaults, used for all samples
    return pd.read_table(tsv_file) # user-supplied grid, used for all samples


# Integrated runs additionally analyse the "orig.ident" identity and force the
# mitochondrial threshold to "auto". This must happen before the parameter
# table below is built, because that table captures percent_mt.
if is_integrated_sample:
    identity_to_analysis.append("orig.ident")
    percent_mt="auto" #if integrated auto

# Single-row default parameter table; replaced by the content of params.tsv when
# that file exists and grid_search is True (see initialization_of_paramspace).
par_df = initialization_of_paramspace("params.tsv",{"percent_mt":[percent_mt],"resolution":[resolution]})
paramspace=Paramspace(par_df)

def dim_reduction_and_marker_plots(paramspace,identity):
    """List the positive-marker plot directories to request.

    UMAP directories come first (when `umap_markers_plot` is set), followed by
    tSNE directories (when `tsne_markers_plot` is set). Each enabled reduction
    contributes one directory per identity, sample and parameter instance.
    """
    reductions = (
        (umap_markers_plot, "/{params}/positive_marker_plots_umap/"),
        (tsne_markers_plot, "/{params}/positive_marker_plots_tsne/"),
    )
    targets=[]
    for enabled, folder in reductions:
        if not enabled:
            continue
        templates = [results_folder + "/" + s + folder + i + "/" for i in identity for s in files]
        targets = targets + expand(templates,params=list(paramspace.instance_patterns))
    return targets

# Extend gene_to_plot with the first column of the optional selected-gene table.
# This stays best-effort (a broken file must not stop the workflow), but the
# failure is reported instead of being silently discarded.
try:
    if os.path.isfile(selected_gene_file):
        df=pd.read_table(selected_gene_file)
        # Skip missing cells and force string names: gene names are later
        # concatenated into output file paths.
        gene_to_plot += df[df.columns[0]].dropna().astype(str).tolist()
except Exception as err:
    print("Warning: could not read genes from " + str(selected_gene_file) + " : " + repr(err), file=sys.stderr)

def selected_gene_plot(paramspace,gene_to_plot,identity):
    """List the per-gene UMAP and tSNE plot files to request.

    Returns an empty list when `gene_to_plot` is empty or None. Otherwise all
    UMAP targets come first, then all tSNE targets; each group holds one PDF
    per identity, gene, sample and parameter instance.
    """
    if not gene_to_plot:
        return []

    targets=[]
    for folder in ("/{params}/selected_gene_plots_umap/", "/{params}/selected_gene_plots_tsne/"):
        templates = [results_folder + "/" + s + folder + i + "/" + g + ".pdf" for i in identity for g in gene_to_plot for s in files]
        targets = targets + expand(templates,params=list(paramspace.instance_patterns))
    return targets

def identity_dependent_dimplot(paramspace,identity):
    """List the dimension-reduction plots (umap, tsne, pca) for each identity.

    Returns an empty list when `identity` is empty; otherwise one PDF per
    reduction, identity, sample and parameter instance.
    """
    if len(identity) == 0:
        return []

    templates = []
    for u in ["umap","tsne","pca"]:
        for x in identity:
            for s in files:
                templates.append(results_folder + "/" + s + "/{params}/plot_dimplot_" + u + "-" + x + ".pdf")
    return expand(templates,params=list(paramspace.instance_patterns))

def identity_dependent_metrics(paramspace,identity):
    """List the cell-count metric outputs for each identity.

    Returns an empty list when `identity` is empty. Otherwise the cell-count
    PDFs come first, then the barplot PDFs, then the barplot HTML files.
    """
    if len(identity) == 0:
        return []

    # (path prefix, file extension) of every metric output, in request order
    outputs = (
        ("/{params}/metrics/plot_cellcount-", ".pdf"),
        ("/{params}/metrics/plot_cellcount_barplot-", ".pdf"),
        ("/{params}/metrics/plot_cellcount_barplot-", ".html"),
    )
    targets=[]
    for prefix, extension in outputs:
        templates = [results_folder + "/" + s + prefix + x + extension for x in identity for s in files]
        targets += expand(templates,params=list(paramspace.instance_patterns))
    return targets

def cellchat_plot(paramspace,identity):
    """List the cellchat output directories, one per identity, sample and parameter instance.

    Returns an empty list when `identity` is empty.
    """
    # The "singler" identity (majority voting) is intentionally not added for now.
    if len(identity) == 0:
        return []

    templates = []
    for x in identity:
        for s in files:
            templates.append(results_folder + "/" + s + "/{params}/cellchat/" + x + "/")
    return expand(templates,params=list(paramspace.instance_patterns))

def celltypist_analysis(paramspace,celltypist_model,identity):
    """List the celltypist plots (dotplot, umap, tsne) to request.

    Parameters:
        paramspace: Paramspace whose instance patterns fill the {params} part.
        celltypist_model: model file name used as sub-directory; when empty or
            None no targets are requested.
        identity: identities to produce plots for. The targets are built from
            this argument rather than from the module-level
            identity_to_analysis, so the caller's choice is honoured.

    Returns a list with the dotplot targets first, then umap, then tsne.
    """
    if not celltypist_model:
        return []

    targets=[]
    for plot_prefix in ("/plot_celltypist_dotplot-", "/plot_celltypist_umap-", "/plot_celltypist_tsne-"):
        templates = [results_folder + "/" + s + "/{params}/celltypist/" + celltypist_model + plot_prefix + x + ".pdf" for x in identity for s in files]
        targets += expand(templates,params=list(paramspace.instance_patterns))
    return targets

def gsea_analysis(paramspace,identity):
    """List the GSEA output tables, one per identity, sample and parameter instance.

    Nothing is requested when `identity` is empty or when the configured
    `gsea_file` does not exist on disk.
    """
    # The "singler" identity is intentionally not added for now.
    if not set(identity) or not os.path.isfile(gsea_file):
        return []

    templates = []
    for x in identity:
        for s in files:
            templates.append(results_folder + "/" + s + "/{params}/gsea/table_gsea_output-" + x + ".xlsx")
    return expand(templates,params=list(paramspace.instance_patterns))

def enrichment_analysis(paramspace,identity):
    """List the GO/KEGG enrichment tables to request.

    Returns an empty list when `identity` is empty. Otherwise the six table
    kinds are requested in a fixed order, each with one .xlsx file per
    identity, sample and parameter instance.
    """
    # The "singler" identity is intentionally not added for now.
    if len(set(identity)) == 0:
        return []

    table_prefixes = (
        "/{params}/enrichment_analysis/table_GO-enrichment-",
        "/{params}/enrichment_analysis/table_GO-geneset_enrichment-",
        "/{params}/enrichment_analysis/table_KEGG-enrichment-",
        "/{params}/enrichment_analysis/table_KEGG-geneset_enrichment-",
        "/{params}/enrichment_analysis/table_KEGG-module_enrichment-",
        "/{params}/enrichment_analysis/table_KEGG-module_geneset_enrichment-",
    )
    tables=[]
    for prefix in table_prefixes:
        templates = [results_folder + "/" + s + prefix + x + ".xlsx" for x in identity for s in files]
        tables += expand(templates,params=list(paramspace.instance_patterns))
    return tables



def kraken_predictions(paramspace,taxa,identity):
    """List the microbiome targets to request.

    - When `kraken_db_folder` is configured: the full-level plot and the
      umap/tsne dimplots per sample, plus the integrated microbiome RDS file.
    - Otherwise, for an integrated sample whose microbiome RDS file already
      exists: the integrated dimplots, then the per-identity significance
      plots, then the per-identity significance tables.
    - In every other case: an empty list.
    """
    def per_instance(templates):
        # Fill the {params} placeholder with every parameter instance.
        return expand(templates,params=list(paramspace.instance_patterns))

    if kraken_db_folder is not None:
        full_level = [results_folder + "/" + s + "/{params}/microbiome/plot_microbiome_full-" + taxa + "-level.pdf" for s in files]
        dimplots = [results_folder + "/" + s + "/{params}/microbiome/plot_microbiome_dimplot-" + taxa + "-" + x + ".pdf" for x in ["umap","tsne"] for s in files]
        return per_instance(full_level) + per_instance(dimplots) + ["analyses_integrated/seurat/" + integration_id + "-" + taxa + ".rds"]

    if is_integrated_sample is not True:
        return []
    if not os.path.isfile("analyses_integrated/seurat/" + integration_id + "-" + taxa + ".rds"):
        return []

    dimplots = [results_folder + "/" + s + "/{params}/microbiome/plot_integrated_microbiome_dimplot-" + taxa + "-" + x + ".pdf" for x in ["umap","tsne"] for s in files]
    significance_plots = [results_folder + "/" + s + "/{params}/microbiome/plot_integrated_significance-" + taxa + "-" + x + ".pdf" for x in identity for s in files]
    significance_tables = [results_folder + "/" + s + "/{params}/microbiome/table_integrated_significance-" + taxa + "-" + x + ".xlsx" for x in identity for s in files]
    return per_instance(dimplots) + per_instance(significance_plots) + per_instance(significance_tables)

def write_main_log(files):
    """Write the run-level log file logs/<logname>.

    The log records the number of samples, their names (space separated), the
    run ID and the selected option. The logs/ directory is created if needed.
    """
    log_dir = pathlib.Path("logs/")
    log_dir.mkdir(parents=True, exist_ok=True)
    with open("logs/" + logname,"w") as f:
        entries = (
            ("Total number of samples processed : ", str(len(files))),
            ("Sample names in this run : ", " ".join(files)),
            ("Run ID : ", str(runid)),
            ("Option : ", str(option)),
        )
        for label, value in entries:
            f.write(label + value + "\n")




def write_sample_log(sample,paramspace):
    """Write the per-sample log <results_folder>/<sample>/logs/<runid>.txt.

    The log lists the sample name, the run identity and the analysis
    parameters in effect for this run, one "label : value" pair per line.
    The log directory is created if needed.
    """
    log_dir=results_folder + "/" + sample + "/logs/"
    pathlib.Path(log_dir).mkdir(parents=True, exist_ok=True)
    with open(log_dir + runid + ".txt","w") as f:
        # (label, value) pairs in the order they appear in the log file.
        # GO-enrichment algorithm/statistics/ontology are deliberately not logged.
        entries = [
            ("Sample name : ", sample),
            ("Run ID : ", str(runid)),
            ("Run option : ", str(option)),
            ("Main directory and params : ", "".join(list(paramspace.instance_patterns))),
            ("Minimum cells : ", str(min_cells)),
            ("Minimum features (nFeature_RNA) : ", str(min_features)),
            ("Maximum features (nFeature_RNA) : ", str(max_features)),
            ("Maximum molecules (nCount_RNA) : ", str(max_molecules)),
            ("Minimum molecules (nCount_RNA) : ", str(min_molecules)),
            ("Percent mitochondrial gene treshold (smaller than) : ", str(percent_mt)),
            ("Percent ribosomal gene treshold (larger than) : ", str(percent_rp)),
            ("Resolution : ", str(resolution)),
            ("Highly variable genes : ", str(highly_variable_features)),
            ("Doublet filter : ", str(config.get("doublet_filter",True))),
            ("Normalization method : ", str(normalization_method)),
            ("Scale factor : ", str(scale_factor)),
            ("LogFC treshold : ", str(logfc_threshold)),
            ("DE test use : ", str(test_use)),
            ("Mapping (GO and KEGG enrichment) : ", str(mapping)),
            ("Organism (KEGG enrichment) : ", str(organism)),
            ("Species (cellchat) : ", str(species)),
            ("SingleR reference : ", str(singler_ref)),
            ("Celltypist model : ", str(celltypist_model)),
            ("Kraken DB folder : ", str(kraken_db_folder)),
            ("Collapse to this taxanomic level : ", str(taxa)),
            ("Kraken confidence param : ", str(confidence)),
            ("Kraken minimum hit groups : ", str(min_hit_groups)),
        ]
        for label, value in entries:
            f.write(label + value + "\n")


def singler_plots(paramspace,identity):
    """List the SingleR plots (score heatmap, clusters, top score heatmap) to request.

    Parameters:
        paramspace: Paramspace whose instance patterns fill the {params} part.
        identity: identities to produce plots for. Both the emptiness check and
            the generated targets use this argument, rather than the
            module-level identity_to_analysis, so they cannot disagree.

    Returns an empty list when `identity` is empty; otherwise the score
    heatmaps first, then the cluster plots, then the top score heatmaps.
    """
    if len(identity) == 0:
        return []

    targets=[]
    for plot_prefix in ("/{params}/singler/plot_score_heatmap-", "/{params}/singler/plot_clusters-", "/{params}/singler/plot_score_heatmap_top-"):
        templates = [results_folder + "/" + s + plot_prefix + x + ".pdf" for x in identity for s in files]
        targets += expand(templates,params=list(paramspace.instance_patterns))
    return targets


def sample_parameter(paramspace,files):
    """Collect every target of the per-sample workflow for the selected option.

    The target groups are cumulative:
      - "standard"/"advanced": marker summaries, SingleR, annotation plots,
        celltypist, enrichment and trajectory outputs;
      - "minimal"/"standard"/"advanced": dimplots, metrics, microbiome and
        selected-gene plots;
      - the above plus "clustree"/"clusteringTree": technical plots;
      - "advanced" only: cellchat and marker plot directories.
    GSEA targets are currently not requested.

    Unless a dry-run flag is present on the command line, the main log and one
    log per sample are written as a side effect.

    Returns a list that mixes nested lists and plain path strings, exactly as
    the groups above are appended.
    """
    def per_sample(suffix):
        # One target per sample and parameter instance for a fixed path suffix.
        return expand([results_folder + "/" + s + suffix for s in files],params=list(paramspace.instance_patterns))

    targets=[]
    if option in ("standard","advanced"):
        targets = targets + [
            expand([results_folder + "/" + s + "/{params}/summarized_markers-for-" + x + ".pdf" for x in identity_to_analysis for s in files],params=list(paramspace.instance_patterns)),
            singler_plots(paramspace,identity_to_analysis),
            per_sample("/{params}/plot_annotation_tsne.pdf"),
            per_sample("/{params}/plot_annotation_umap.pdf"),
            per_sample("/{params}/plot_annotation_pca.pdf"),
            celltypist_analysis(paramspace,celltypist_model,identity_to_analysis),
            enrichment_analysis(paramspace,identity_to_analysis),
            per_sample("/{params}/trajectory/plot_monocle-partition-plot.pdf"),
        ]

    if option in ("minimal","standard","advanced"):
        without_orig_ident = [x for x in identity_to_analysis if x != "orig.ident"]
        targets = targets + [
            identity_dependent_dimplot(paramspace,identity_to_analysis),
            identity_dependent_metrics(paramspace,without_orig_ident),
            kraken_predictions(paramspace,taxa,identity_to_analysis),
            selected_gene_plot(paramspace,gene_to_plot,identity_to_analysis),
        ]

    known_options = ("clustree","clusteringTree","minimal","standard","advanced")
    if option in known_options:
        for suffix in ("/{params}/technicals/plot_clustree.pdf", "/{params}/technicals/plot_nFeature.pdf", "/{params}/technicals/plot_nCount.pdf"):
            targets += per_sample(suffix)

    if option in ("advanced",):
        targets += cellchat_plot(paramspace,[x for x in identity_to_analysis if x != "orig.ident"])
        targets += dim_reduction_and_marker_plots(paramspace,identity_to_analysis)

    if option not in known_options:
        print("Please select a correct option...")

    # Logs are only written for real runs, not when a dry-run flag is given.
    is_dry_run = any(arg in ("--dry-run", "--dryrun", "-n") for arg in sys.argv)
    if not is_dry_run:
        write_main_log(files)
        for sample in files:
            write_sample_log(sample, paramspace)
    return targets



# Entry point selection: the chosen option decides which rule files are included
# and what the final "all" rule requests.
if option in ["integration","integrate"]:
    #integration_files = list(pathlib.Path("analyses/processed/").rglob("*.rds"))
    # Collect processed RDS files; per the path pattern, a holds the percent_mt
    # values, b the resolution values and c the sample names (one entry per file).
    a,b,c=cellsnake_glob_wildcards("analyses/processed/percent_mt~{a}/resolution~{b}/{c}.rds")
    integration_files=expand("analyses/processed/percent_mt~{a}/resolution~{b}/{c}.rds",zip,a=a,b=b,c=c)

    total_samples_to_merge=len(set(c)) # distinct sample names
    total_rds_files=len(list(zip(a,b,c))) # all RDS files found
    total_mt_samples=len(set(zip(a,c))) # distinct (percent_mt, sample) pairs
    
    #print(total_samples_to_merge)
    #print(total_mt_samples)
    if total_samples_to_merge > 1: # there are at least two samples
        if total_rds_files > total_samples_to_merge:
            # At least one sample has more than one RDS file.

            integration_files=[]
            if total_mt_samples == total_samples_to_merge:
                # Every sample has a single percent_mt value, so the extra files only
                # differ in resolution: take the first match found for each sample.
                print("I detected more than one RDS file per sample, I will select one of them to merge...")
                for i in set(c):
                    integration_files.append(list(pathlib.Path("analyses/processed/").rglob(i + ".rds"))[0])
            else:
                # Ambiguous case: integration_files stays empty, so the "all" rule
                # below requests nothing.
                print("There are identical samples with different MT content, I am not sure how to merge them, better to remove some RDS files manually...")
                print(expand("analyses/processed/percent_mt~{a}/resolution~{b}/{c}.rds",zip,a=a,b=b,c=c))
     
    #print(integration_files)
    include: "rules/integration.smk"
    # The integrated object is only requested when more than one file is to be merged.
    rule all:
        input:
             "analyses_integrated/seurat/" + integration_id + ".rds" if len(integration_files) > 1 else []
    
    # Write the main log for real runs only (skipped when a dry-run flag is given).
    if not any(b in ["--dry-run", "--dryrun", "-n"] for b in sys.argv):
        write_main_log([str(i) for i in integration_files])

elif option in ["minimal","standard","clustree","clusteringTree","advanced"] and files:
    # Per-sample workflow: requires at least one detected sample.
    include: "rules/seurat.smk"
    include: "rules/microbiome.smk"
    rule all:
        input:
            sample_parameter(paramspace,files)

else:
    # Unknown option, or no samples were detected.
    print("Please select a correct option or no files detected...")
    pass
