#!/usr/bin/env python 

import os
import sys
import argparse
import seaborn as sns
import matplotlib.pyplot as plt

# Command-line interface: two splicekit runs (main folder, comparison name and
# results file name for each), an output file name prefix, and optional
# filtering switches.
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)

# Positional arguments, registered in the order they must appear on the
# command line.
for _positional, _description in (
    ("splicekit_run1", "Main folder with splicekit run 1"),
    ("comparison1", "Name of comparison to use in splicekit run 1"),
    ("comparison1_fname", ""),
    ("splicekit_run2", "Main folder with splicekit run 2"),
    ("comparison2", "Name of comparison to use in splicekit run 2"),
    ("comparison2_fname", ""),
    ("fname_out", ""),
):
    parser.add_argument(_positional, help=_description)

# Numeric thresholds; each defaults to None, which means "do not filter".
parser.add_argument(
    "-fdr",
    "--fdr",
    type=float,
    default=None,
    help="FDR threshold, default no filter",
)
parser.add_argument(
    "-logFC",
    "--logFC",
    type=float,
    default=None,
    help="logFC threshold, default=None (do not use as filter)",
)
parser.add_argument(
    "-dpfi",
    "--dpfi",
    type=float,
    default=None,
    help="delta(percentage feature inclusion) threshold, default=None (do not use as filter)",
)

# Boolean switch: keep only features that are hits in both comparisons.
parser.add_argument(
    "-intersect",
    "--intersect",
    action="store_true",
    help="Report only features that are hit in both comparisons",
)

# Comma separated list of feature types to process.
parser.add_argument(
    "-features",
    "--features",
    default="exons,junctions,donor_anchors,acceptor_anchors",
    help="Which features to process, comma separated (default: 'exons,junctions,donor_anchors,acceptor_anchors')",
)

args = parser.parse_args()

def print_version():
    """Print the tool name/version and the active filter thresholds.

    Reads the thresholds from the module-level ``args`` namespace and writes
    one line per threshold to stdout, followed by a ``----`` separator line.
    """
    threshold_lines = (
        ("FDR threshold = ", args.fdr),
        ("logFC threshold = ", args.logFC),
        ("dpfi threshold = ", args.dpfi),
    )
    print("splicecompare v0.3")
    for label, value in threshold_lines:
        # print() inserts its own separator between label and value.
        print(label, value)
    print("----")

print_version()


def _fdr_column(header):
    """Return the name of the FDR column found in *header*, or None.

    Both spellings are accepted; "FDR" takes precedence when both are present.
    """
    for name in ("FDR", "fdr"):
        if name in header:
            return name
    return None


def _passes_thresholds(data, fdr_key):
    """Return True if the row *data* satisfies all thresholds set in ``args``.

    A row is rejected when its FDR is above ``args.fdr``, or when the absolute
    value of its logFC / delta_pfi is below ``args.logFC`` / ``args.dpfi``.
    Thresholds left at None are not applied; the FDR threshold is also not
    applied when the table has no FDR column (*fdr_key* is None).
    """
    if args.fdr is not None and fdr_key is not None:
        if float(data[fdr_key]) > args.fdr:
            return False
    if args.logFC is not None and abs(float(data["logFC"])) < args.logFC:
        return False
    if args.dpfi is not None and abs(float(data["delta_pfi"])) < args.dpfi:
        return False
    return True


def _load_results(fname):
    """Read one tab-separated results table.

    Returns a tuple ``(rows, hits, fdr_key)``:
    ``rows`` maps feature_id -> row dict for every row in the file,
    ``hits`` is the subset of ``rows`` that passes the thresholds, and
    ``fdr_key`` is the name of the FDR column in this file (or None).

    Raises OSError if the file cannot be opened, KeyError if a required
    column is missing, and ValueError if a filtered column is not numeric.
    """
    rows = {}
    hits = {}
    with open(fname, "rt") as f:
        header = f.readline().rstrip("\r\n").split("\t")
        fdr_key = _fdr_column(header)
        if args.fdr is not None and fdr_key is None:
            # Make it visible that the requested filter cannot be applied.
            print(f"warning: no FDR column in {fname}, FDR threshold not applied", file=sys.stderr)
        for line in f:
            line = line.rstrip("\r\n")
            if not line:
                # Tolerate blank (e.g. trailing) lines.
                continue
            data = dict(zip(header, line.split("\t")))
            feature_id = data["feature_id"]
            rows[feature_id] = data
            if _passes_thresholds(data, fdr_key):
                hits[feature_id] = data
    return rows, hits, fdr_key


# (main folder, results file name) of run 1 and run 2. The runs are tracked by
# position rather than by comparison name, so the same comparison name can be
# used in both runs without one overwriting the other.
runs = [
    (args.splicekit_run1, args.comparison1_fname),
    (args.splicekit_run2, args.comparison2_fname),
]

for feature_type in (name.strip() for name in args.features.split(",")):
    if not feature_type:
        continue

    # Read both inputs before creating the output file, so that a missing
    # input does not leave an empty output file behind.
    (rows1, hits1, fdr_key1), (rows2, hits2, fdr_key2) = [
        _load_results(os.path.join(run_folder, "results", f"results_edgeR_{feature_type}", results_fname))
        for run_folder, results_fname in runs
    ]

    # Features for which both comparisons satisfy the thresholds
    # (FDR <= thr, |logFC| >= thr, |delta_pfi| >= thr).
    shared_hits = hits1.keys() & hits2.keys()
    candidates = shared_hits if args.intersect else hits1.keys() | hits2.keys()

    final_results = []
    for feature_id in candidates:
        data1 = rows1.get(feature_id)
        data2 = rows2.get(feature_id)
        if data1 is None or data2 is None:
            # Features absent from one of the two input files are not reported.
            continue
        row = [feature_id, data1["chr"], data1["strand"], data1["gene_name"]]
        # Each run is read with its own FDR column name ("fdr" or "FDR").
        for data, fdr_key in ((data1, fdr_key1), (data2, fdr_key2)):
            row.append(data.get(fdr_key, ""))
            row.append(data.get("logFC", ""))
            row.append(data.get("delta_pfi", ""))
        combined_logFC = abs(float(data1.get("logFC", 0))) + abs(float(data2.get("logFC", 0)))
        final_results.append((combined_logFC, row))

    # Largest combined |logFC| first; ties fall back to comparing the rows.
    final_results.sort(reverse=True)

    out_header = [
        "feature_id",
        "chr",
        "strand",
        "gene_name",
        f"fdr_{args.comparison1}",
        f"logFC_{args.comparison1}",
        f"delta_pfi_{args.comparison1}",
        f"fdr_{args.comparison2}",
        f"logFC_{args.comparison2}",
        f"delta_pfi_{args.comparison2}",
    ]
    with open(f"{args.fname_out}_{feature_type}.tab", "wt") as fout:
        fout.write(f"# {len(hits1)} {feature_type} detected for {args.comparison1}\n")
        fout.write(f"# {len(hits2)} {feature_type} detected for {args.comparison2}\n")
        fout.write(f"# intersect = {len(shared_hits)}\n")
        fout.write("\t".join(out_header) + "\n")
        for _, row in final_results:
            fout.write("\t".join(row) + "\n")