#!python
# coding=utf-8

#  PINTS: Peak Identifier for Nascent Transcripts Sequencing
#  Copyright (c) 2019-2021 Li Yao at the Yu Lab.
#
#  This program is free software: you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation, either version 3 of the License, or
#  (at your option) any later version.
#
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.

import argparse
import datetime
import gzip
import logging
import os
import sys
import warnings
from multiprocessing import Pool

# Module-wide logging setup: records go to a StreamHandler (stderr by default) with the logger name,
# a timestamp and the level; every function in this module reports through the "PINTS - Caller" logger.
logging.basicConfig(format="%(name)s - %(asctime)s - %(levelname)s: %(message)s",
                    datefmt="%d-%b-%y %H:%M:%S",
                    level=logging.INFO,
                    handlers=[
                        logging.StreamHandler()
                    ])
logger = logging.getLogger("PINTS - Caller")

try:
    import numpy as np
    import pandas as pd
    import scipy
    import pysam
    import pybedtools
    from scipy.stats import poisson, binom_test, probplot, uniform, nbinom
    from scipy.signal import find_peaks, peak_widths
    from scipy.ndimage import gaussian_filter1d
    from pybedtools import BedTool
    from statsmodels.stats.multitest import multipletests
    from pints.stats_engine import Poisson, ZIP, NegativeBinomial, ZINB, pval_dist, get_rank, bgIQR, pkIQR, \
        independent_filtering
    from pints.io_engine import get_read_signal, get_coverage, get_coverage_bw, log_assert, normalize_using_input, \
        index_bed_file, peak_bed_to_gtf
    from pints import __version__
except ImportError as e:
    missing_package = str(e).replace("No module named '", "").replace("'", "")
    logger.error("Please install %s first!" % missing_package)
    sys.exit(-1)

# Intermediate files registered by peaks_single_strand (per-chromosome merged windows and the
# compressed/indexed sub-peak beds); presumably removed by cleanup code elsewhere — TODO confirm.
housekeeping_files = []
# Column order of the per-strand peak tables written by peaks_single_strand.
COMMON_HEADER = ('chromosome', 'start', 'end', 'name', 'padj', 'strand', 'reads',
                 'pval', 'mu_0', 'pi_0', 'mu_1', 'pi_1', 'ler_1', 'ler_2', 'ler_3',
                 'summit', 'summit_val')
# Statistical model used by check_window / check_window_chromosome (must provide fit() and sf()).
# It is None here; assumed to be assigned elsewhere before peak calling starts — TODO confirm.
stat_tester = None
# Helper used by check_window for local environment refinement (remove_peaks_in_local_env).
# Likewise assumed to be assigned elsewhere before use.
iqr_obj = None


def handle_exception(exc_type, exc_value, exc_traceback):
    """
    Report an uncaught exception through the module logger

    The signature matches ``sys.excepthook`` (see Refs). ``KeyboardInterrupt``
    (and its subclasses) is handed to Python's default hook unchanged, so
    Ctrl+C keeps its usual behaviour. Every other exception is written with
    ``logger.error`` including the full traceback.

    Parameters
    ----------
    exc_type : type
        Class of the raised exception
    exc_value : BaseException
        The exception instance
    exc_traceback : traceback or None
        Traceback attached to the exception

    Returns
    -------
    None

    Refs
    ----
    https://stackoverflow.com/questions/6234405/logging-uncaught-exceptions-in-python/16993115#16993115
    """
    info = (exc_type, exc_value, exc_traceback)
    if not issubclass(exc_type, KeyboardInterrupt):
        logger.error("Uncaught exception", exc_info=info)
    else:
        sys.__excepthook__(*info)


def run_command(cmd, repress_log=False):
    """
    Run a shell command and capture its output

    Parameters
    ----------
    cmd : str
        Command line. It is executed through the shell (``shell=True``), so it
        must only be built from trusted values.
    repress_log : bool
        When False (default), a non-zero exit status is reported through the
        module logger. When True, failures are not logged and the caller is
        expected to inspect ``return_code`` itself.

    Returns
    -------
    stdout : str
        Stdout output of the child process, decoded as UTF-8. Undecodable
        bytes are replaced by U+FFFD instead of raising.
    stderr : str
        Stderr output of the child process, decoded the same way
    return_code : int
        Exit status of the child process
    """
    from subprocess import Popen, PIPE
    # the context manager guarantees the pipes are closed and the child is reaped
    with Popen(cmd, shell=True, stderr=PIPE, stdout=PIPE) as p:
        stdout, stderr = p.communicate()
    # external tools may emit non-UTF-8 bytes; never crash just because of the output encoding
    stdout = stdout.decode("utf-8", errors="replace")
    stderr = stderr.decode("utf-8", errors="replace")
    if p.returncode != 0 and not repress_log:
        logger.error("Failed to run command %s" % cmd)
    return stdout, stderr, p.returncode


def runtime_check():
    """
    Runtime check, make sure all dependent command-line tools are callable

    Every tool in ``("bgzip", "tabix", "bedtools")`` is looked up on ``PATH``
    with ``shutil.which``. All missing tools are reported, one error line each,
    before the process terminates. This lets users fix their environment in
    one go.

    Returns
    -------
    None

    Raises
    ------
    SystemExit
        With exit status 1 if at least one required tool cannot be found
    """
    import shutil
    if sys.platform == "win32":
        logger.warning("No test had performed on Windows, so it might be buggy.")
    dependent_tools = ("bgzip", "tabix", "bedtools")
    missing_tools = [tool for tool in dependent_tools if shutil.which(tool) is None]
    for tool in missing_tools:
        logger.error("Required tool %s is not callable" % tool)
    if missing_tools:
        # sys.exit instead of the ``exit`` builtin: the latter is injected by the
        # ``site`` module and is not available e.g. under ``python -S``
        sys.exit(1)


def merge_intervals(intervals, distance=0):
    """
    Merge intervals

    Parameters
    ----------
    intervals : tuple/list
        List / tuple of intervals. Only the first two items of every interval
        (start, end) are used; any extra items are ignored.
    distance : int
        Maximum distance between features allowed for features to be merged.
        Default is 0. That is, overlapping and/or book-ended features are merged.

    Returns
    -------
    merged_intervals : list
        List of ``(start, end)`` tuples sorted by start. Consecutive entries are
        more than ``distance`` apart from each other.

    Refs
    ----
        https://www.geeksforgeeks.org/merging-intervals/
    """
    log_assert(distance >= 0, "distance need to be >= 0", logger)
    merged = []
    for start, end, *_ in sorted(intervals, key=lambda t: t[0]):
        if merged and start <= merged[-1][1] + distance:
            # The current interval touches the last merged one. Only extend the end
            # if it reaches further: for ((6, 8), (1, 9), (2, 4), (4, 7)) blindly
            # taking the latest end would give (1, 8) instead of (1, 9).
            if end > merged[-1][1]:
                merged[-1] = (merged[-1][0], end)
        else:
            merged.append((start, end))
    return merged


def sliding_window(chromosome_coverage, window_size, step_size):
    """
    Generate sliding windows

    Windows start at 0, ``step_size``, ``2 * step_size``, ... Only windows that
    fit entirely inside ``chromosome_coverage`` are produced. If
    ``window_size`` exceeds the length of the track, nothing is yielded.

    Parameters
    ----------
    chromosome_coverage : array-like
        0-based per base coverage array for a certain chromosome
    window_size : int
        Window size for scanning
    step_size : int
        Step size for scanning

    Yields
    ------
    window : scalar
        Read counts (sum of coverage) in this window
    coordinates : tuple of int
        ``(start, end)``: 0-based, half-open coordinates of this window

    Raises
    ------
    ValueError
        If ``step_size`` < 1 or ``chromosome_coverage`` is empty
    """
    if step_size < 1:
        logger.error("step_size must >= 1")
        raise ValueError("step_size must >= 1")
    if len(chromosome_coverage) < 1:
        logger.error("chromosome_coverage must >= 1")
        raise ValueError("chromosome_coverage must >= 1")

    # exact integer arithmetic (no float floor); a negative count means the
    # window is larger than the track and range() then yields nothing
    total_bins = (len(chromosome_coverage) - window_size) // step_size + 1
    for bin_index in range(total_bins):
        start = bin_index * step_size
        end = start + window_size
        yield np.sum(chromosome_coverage[start:end]), (start, end)


def check_window(coord_start, coord_end, mu_peak, var_peak, pi_peak, chromosome_coverage, peak_in_bg_threshold,
                 mu_bkg_minimum, sp_bed_handler, chromosome_name, fdr_target, cache, small_window_threshold=5,
                 flanking=(10000, 5000, 1000), disable_ler=False):
    """
    Calculate p-value for a peak against its local background

    For every length in ``flanking``, half of that length is taken on each side
    of the peak as the local environment. The left side is clipped at 0.
    ``iqr_obj.remove_peaks_in_local_env`` masks candidate peaks in that
    environment (local environment refinement, LER). ``stat_tester.fit``
    models the remaining signal, and the fitted parameters are averaged over
    all flanking sizes. The peak is then scored with ``stat_tester.sf``
    against this averaged background.

    Parameters
    ----------
    coord_start : int
        0-based start coordinate
    coord_end : int
        0-based end coordinate
    mu_peak : float
        mu_mle of the peak
    var_peak : float or None
        var_mle of the peak, can be None if not evaluated
    pi_peak : float
        pi_mle of the peak
    chromosome_coverage : array-like
        0-based per base coverage array for a certain chromosome
    peak_in_bg_threshold : float
        Candidate peaks with density higher than this value will be removed from the local environment
    mu_bkg_minimum : float or None
        minimum mu for background; a smaller averaged mu is raised to this value
    sp_bed_handler : pysam.TabixFile
        pysam.TabixFile object for subpeak bed file
    chromosome_name : str
        name of the chromosome/contig to call peaks
    fdr_target : float
        FDR target
    cache : dict
        cache for IQR, leave it as it is
    small_window_threshold : int
        Candidate peaks with lengths shorter than this value will be skipped
    flanking : tuple
        Lengths of local environment that this function will check
    disable_ler : bool
        Disable local environment refinement, by default, False

    Returns
    -------
    p_value : float
        p_value for the peak (1.0 for skipped windows: too short or without reads)
    window_value: int
        read counts in this window
    mu_0 : float
        mu for local env (0 for skipped windows)
    pi_0: float
        pi for local env (0 for skipped windows)
    ler_counts : list
        One entry per element of ``flanking``: # of local peaks masked by LER
        in that environment (all zeros for skipped windows)
    """
    selected_window = chromosome_coverage[coord_start:coord_end]
    window_value = np.sum(selected_window)
    if coord_end - coord_start < small_window_threshold or window_value == 0:
        # same shape as the regular result: one LER count per flanking size
        return 1., window_value, 0, 0, [0] * len(flanking)
    half_flanks = np.asarray(flanking, dtype=int) // 2
    mus = []
    variances = []
    pis = []
    ler_counts = []
    for f in half_flanks:
        bg, n_masked = iqr_obj.remove_peaks_in_local_env(stat_tester=stat_tester, bed_handler=sp_bed_handler,
                                                         chromosome=chromosome_name,
                                                         query_start_left=max(coord_start - f, 0),
                                                         query_end_left=coord_start,
                                                         query_start_right=coord_end,
                                                         query_end_right=coord_end + f,
                                                         small_window_threshold=small_window_threshold,
                                                         peak_in_bg_threshold=peak_in_bg_threshold,
                                                         coverage_info=chromosome_coverage,
                                                         fdr_target=fdr_target, cache=cache,
                                                         disable_ler=disable_ler)

        mu_, var_, pi_, _, _ = stat_tester.fit(bg)
        mus.append(mu_)
        variances.append(var_)
        pis.append(pi_)
        ler_counts.append(n_masked)

    # background parameters are averaged over all flanking sizes
    mu_0 = np.mean(mus)
    var_0 = np.mean(variances)
    pi_0 = np.mean(pis)
    if mu_bkg_minimum is not None and mu_0 < mu_bkg_minimum:
        mu_0 = mu_bkg_minimum

    pvalue = stat_tester.sf(mu_peak, var_peak, pi_peak, mu_0, var_0, pi_0)

    return pvalue, window_value, mu_0, pi_0, ler_counts


def quasi_max_score_segment(candidates, donor_tolerance, receptor_tolerance, ce_trigger, max_distance):
    """
    Max score segment algorithm to join adjacent sub peaks/seeds

    Two independent greedy passes run over ``candidates``. These are presumably
    sorted by start and non-overlapping; ``cut_peaks`` passes the output of
    ``merge_intervals``.

    * forward: every candidate is tried against its right neighbour;
    * reverse: every candidate is tried against its left neighbour.

    For a candidate ``c`` and its neighbour, the merged density is the sum of
    both read counts divided by the span from the leftmost start to the
    rightmost end. If that density is below ``donor_tolerance * density of c``,
    ``c`` is kept unchanged. Otherwise the merged segment is emitted, unless
    ``c`` is shorter than ``ce_trigger`` or the gap to the neighbour exceeds
    ``max_distance``. In that case ``c`` is not emitted by that pass at all.
    The last candidate (forward pass) and the first candidate (reverse pass)
    are always kept as they are.

    Parameters
    ----------
    candidates : list or tuple
        list of candidate peaks, each peak is also a list with 4 elements:
            start
            end
            read counts
            density
    donor_tolerance : float
        Donor tolerance in best score segments
    receptor_tolerance : float
        Receptor tolerance in best score segments.
        NOTE(review): currently unused — the check that consumed it is commented out below.
    ce_trigger : int
        Candidates narrower than this are not merged (see above)
    max_distance : int
        Max distance allowed to join two sub peaks/seeds

    Returns
    -------
    fwd_search : list
        Kept candidates and merged ``(start, end, read counts, density)`` tuples from the forward search
    rev_search : list
        Same for the reverse search, in reverse order of position
    """
    fwd_search = []
    rev_search = []
    # forward search
    for k, c in enumerate(candidates):
        if k < len(candidates) - 1:
            new_total = c[2] + candidates[k + 1][2]
            new_density = new_total / (candidates[k + 1][1] - c[0])
            if new_density >= donor_tolerance * c[3]:
                # c too narrow, or too far from its right neighbour
                distance_check = c[1] - c[0] < ce_trigger or candidates[k + 1][0] - c[1] > max_distance
                if distance_check:  # or new_density < (receptor_tolerance * candidates[k + 1][3]):
                    # NOTE(review): c is neither merged nor kept by this pass; confirm this is intended
                    # fwd_search.append(c)
                    continue
                merged = (c[0], candidates[k + 1][1], new_total, new_density)
                fwd_search.append(merged)
            else:
                # merging would dilute c too much, keep it on its own
                fwd_search.append(c)
        else:
            # last candidate has no right neighbour
            fwd_search.append(c)
    # reverse search
    for k in range(len(candidates) - 1, -1, -1):
        c = candidates[k]
        if k > 0:
            new_total = c[2] + candidates[k - 1][2]
            new_density = new_total / (c[1] - candidates[k - 1][0])
            if new_density >= donor_tolerance * c[3]:
                # c too narrow, or too far from its left neighbour
                distance_check = c[1] - c[0] < ce_trigger or c[0] - candidates[k - 1][1] > max_distance
                if distance_check:  # or new_density < (receptor_tolerance * candidates[k - 1][3]):
                    # NOTE(review): as in the forward pass, c is dropped from this pass here
                    continue
                merged = (candidates[k - 1][0], c[1], new_total, new_density)
                rev_search.append(merged)
            else:
                rev_search.append(c)
        else:
            # first candidate has no left neighbour
            rev_search.append(c)
    return fwd_search, rev_search


def merge_covs(covs, chromosome_of_interest):
    """
    Merge coverage tracks

    Loads the coverage of ``chromosome_of_interest`` for every replicate with
    ``np.load`` and adds the tracks element-wise.

    Parameters
    ----------
    covs : list of dicts
        One dict per replicate, mapping a chromosome name to something
        ``np.load`` can read (e.g. the path of a ``.npy`` file)
    chromosome_of_interest : str
        Name of the chromosome/contig to be working on

    Returns
    -------
    merged_coverage : np.ndarray or None
        int32 array with the summed coverage, sized by the first replicate;
        None when ``covs`` is empty
    """
    merged = None
    for per_rep in covs:
        loaded = np.load(per_rep[chromosome_of_interest])
        merged = np.zeros(loaded.shape[0], dtype=np.int32) if merged is None else merged
        merged += loaded
        # release this replicate before the next one is loaded to limit peak memory
        del loaded
    return merged


def cut_peaks_dry_run(annotation_gtf, pl_cov_files, mn_cov_files, highlight_chromosome="chr1",
                      output_diagnostics=False, save_to=None):
    """
    Select optimal alpha values to join sub peaks/seeds

    Procedure:

    1. Protein-coding transcripts on ``highlight_chromosome`` are loaded from
       ``annotation_gtf``, and a +/-200 bp window is built around every
       annotated TSS.
    2. Signal inside these windows is cut into sub peaks with ``cut_peaks``
       for 101 donor tolerances evenly spaced in [0, 1].
    3. For every tolerance, two statistics are recorded: the median element
       size, and the ambiguous rate. The ambiguous rate is the fraction of
       TSS-overlapping elements that overlap more than one TSS.
    4. Both curves are Gaussian-smoothed. The chosen tolerance is the one
       where the ambiguous rate is flattest, at or after the point where
       element size increases fastest.

    Parameters
    ----------
    annotation_gtf : str
        Gene annotation gtf file
    pl_cov_files : list of dicts
        List of coverage dicts for each rep (forward strand)
    mn_cov_files : list of dicts
        List of coverage dicts for each rep (reverse strand)
    highlight_chromosome : str
        Name of the chromosome/contig to be working on
    output_diagnostics : bool
        Write out diagnostics
    save_to : None or str
        Path of the diagnostic figure (only used when ``output_diagnostics`` is True)
    Returns
    -------
    selected_threshold : float
        Optimal donor tolerance
    """
    from pints.io_engine import parse_gtf
    annotations = parse_gtf(annotation_gtf)
    transcripts_pc = annotations.loc[np.logical_and(np.logical_and(annotations.feature == "transcript",
                                                                   annotations.gene_type == "protein_coding"),
                                                    annotations.seqname == highlight_chromosome), :]
    log_assert(transcripts_pc.shape[0] > 0,
               "Cannot parse any annotations for protein-coding genes from provided annotations", logger)
    # explicit copies: columns are modified below, and doing that on a slice of
    # another frame triggers pandas' SettingWithCopyWarning
    pct_bed = transcripts_pc.loc[:, ("seqname", "start", "end", "transcript_id", "gene_name", "strand",
                                     "gene_id")].copy()
    pct_bed.start -= 1  # presumably 1-based GTF start -> 0-based BED start
    pct_tss = pct_bed.loc[:, ("seqname", "start", "end", "transcript_id", "gene_name", "strand")].copy()
    # 1-bp TSS: first base for "+" transcripts, last base for "-" transcripts
    pct_tss["start"] = pct_tss.apply(lambda x: x["start"] if x["strand"] == "+" else x["end"] - 1, axis=1)
    pct_tss["end"] = pct_tss.apply(lambda x: x["start"] + 1 if x["strand"] == "+" else x["end"], axis=1)
    pct_tss_bed = BedTool.from_dataframe(pct_tss)
    # +/-200 bp search windows; clip so TSSs close to the chromosome start do not
    # produce invalid negative BED starts
    pct_tss["start"] = (pct_tss["start"] - 200).clip(lower=0)
    pct_tss["end"] += 200
    tss_window = BedTool.from_dataframe(pct_tss).sort().merge(s=True)
    logger.info("%d annotated TSSs loaded" % pct_tss.shape[0])
    pl_cov = merge_covs(pl_cov_files, highlight_chromosome)
    mn_cov = merge_covs(mn_cov_files, highlight_chromosome)
    strand_covs = (("+", pl_cov), ("-", mn_cov))

    def cp_atom(coverage_track, abs_start, abs_end, donor_tolerance, receptor_tolerance=0.1, ce_trigger=3):
        """Cut sub peaks in coverage_track[abs_start:abs_end]; return their absolute starts and ends"""
        starts = []
        ends = []
        sub_peaks = cut_peaks(coverage_track[abs_start:abs_end],
                              peak_rel_height=1.,
                              donor_tolerance=donor_tolerance,
                              receptor_tolerance=receptor_tolerance,
                              ce_trigger=ce_trigger)
        for sp in sub_peaks:
            starts.append(sp[0] + abs_start)
            ends.append(sp[1] + abs_start)
        return starts, ends

    search_range = np.linspace(0, 1, 101)
    ambs = np.zeros(search_range.shape[0])
    median_sizes = np.zeros(search_range.shape[0])
    for i, dt in enumerate(search_range):
        predicted_peaks_dicts = []
        for window in tss_window:
            for strand, coverage in strand_covs:
                if np.sum(coverage[window.start:window.stop]) > 0:
                    ss, es = cp_atom(coverage, window.start, window.stop, donor_tolerance=dt)
                    for s, e in zip(ss, es):
                        predicted_peaks_dicts.append({"seqname": highlight_chromosome,
                                                      "start": s, "end": e, "name": ".", "score": ".",
                                                      "strand": strand})
        predicted_peaks = BedTool.from_dataframe(pd.DataFrame(predicted_peaks_dicts))
        tmp_result = predicted_peaks.intersect(pct_tss_bed, c=True, s=True)
        tmp_df = tmp_result.to_dataframe(names=("seqname", "start", "end", "name", "score", "strand", "hits"))
        n_tss_elements = (tmp_df["hits"] > 0).sum()
        # no element overlapping a TSS means nothing is ambiguous; 0/0 would give a
        # NaN that the Gaussian smoothing below spreads to neighbouring thresholds
        ambs[i] = (tmp_df["hits"] > 1).sum() / n_tss_elements if n_tss_elements > 0 else 0.
        median_sizes[i] = np.median(tmp_df["end"] - tmp_df["start"])

    smoothed_median_sizes = gaussian_filter1d(median_sizes, 1)
    smoothed_ambs = gaussian_filter1d(ambs, 1)
    # steepest increase of element size ...
    fod_size = np.gradient(smoothed_median_sizes)
    upper_bound_from_size = np.argmax(fod_size)
    # ... then the flattest ambiguous rate at or after that point
    fod_ar = np.abs(np.gradient(smoothed_ambs))[upper_bound_from_size:]
    selected_locus = np.argmin(fod_ar) + upper_bound_from_size
    selected_threshold = search_range[selected_locus]

    if output_diagnostics and save_to is not None:
        import matplotlib.pyplot as plt
        fig, axs = plt.subplots(2, 1)
        axs[0].plot(search_range, smoothed_median_sizes, label="Median size of elements")
        axs[0].axvline(search_range[upper_bound_from_size])
        axs[0].legend()

        axs[1].plot(search_range, smoothed_ambs, label="Ambiguous rate")
        axs[1].axvline(search_range[selected_locus])
        axs[1].legend()
        plt.tight_layout()
        plt.savefig(save_to, transparent=True, bbox_inches="tight")
        plt.close()

    return selected_threshold


def cut_peaks(window, peak_rel_height, donor_tolerance, receptor_tolerance, ce_trigger, max_distance=30):
    """
    Cut peaks from the given window

    Steps:

    1. ``scipy.signal.find_peaks`` detects local maxima.
    2. ``scipy.signal.peak_widths`` measures their extent at
       ``peak_rel_height``.
    3. Overlapping or adjacent extents (distance <= 1) are merged into sub
       peaks.
    4. ``quasi_max_score_segment`` joins neighbouring sub peaks.
    5. The union of its forward and reverse results is merged once more.

    Parameters
    ----------
    window : array-like
        Per base read counts / coverage
    peak_rel_height : float
        Relative height (passed to ``find_peaks`` and ``peak_widths``) at which
        peak extents are measured; 1.0 measures a peak at its full base
    donor_tolerance : float
        From sub peak seeking for merging, the new density should be larger than dt*prev_d
    receptor_tolerance : float
        From sub peak seeking being merged, the new density should be larger than rt*prev_d
    ce_trigger : int
        Sub peak narrower than cet will trigger receptor tolerance check
    max_distance : int
        max distance between two subpeaks to be joined, by default, 30
    Returns
    -------
    merged_intervals : list
        Merged intervals [(start_1, end_1), ... , (start_n, end_n)], relative to the start of ``window``
    """
    peaks, _ = find_peaks(window, rel_height=peak_rel_height)
    _, _, starts, ends = peak_widths(window, peaks, rel_height=peak_rel_height)
    # interpolated (float) positions are truncated to integer coordinates
    intervals = [(int(start), int(end)) for start, end in zip(starts, ends)]
    candidates = []
    for sp_start, sp_end in merge_intervals(intervals=intervals, distance=1):
        # the end position is treated as inclusive when counting reads
        events = np.sum(window[sp_start:sp_end + 1])
        candidates.append((sp_start, sp_end, events, events / (sp_end - sp_start)))

    f, r = quasi_max_score_segment(candidates=candidates, donor_tolerance=donor_tolerance,
                                   receptor_tolerance=receptor_tolerance, ce_trigger=ce_trigger,
                                   max_distance=max_distance)
    f.extend(r)
    final = merge_intervals(f, distance=1)
    return final


def check_window_chromosome(rc_file, output_file, strand_sign, chromosome_name, subpeak_file, fdr_target,
                            read_counts_threshold, small_peak_threshold=5, min_mu_percent=0.1, disable_ler=False):
    """
    Evaluate windows on a chromosome

    Runs in two passes:

    1. Every sub peak listed in ``subpeak_file`` gets its own density
       estimate (mu, var, pi) and summit. These are written to
       ``output_file`` with ``.bed`` replaced by
       ``_subpeaks_<chromosome_name>.bed``. That file is then handed to
       ``index_bed_file``, and its ``.gz`` version is read afterwards.
    2. Every sub peak in the indexed file is tested against its local
       background with ``check_window``. Sub peaks with more than
       ``read_counts_threshold`` reads are reported.

    Parameters
    ----------
    rc_file : str
        Path to numpy saved read coverage info
    output_file : str
        Path to store outputs
    strand_sign : str
        Strand of windows
    chromosome_name : str
        Name of this chromosome
    subpeak_file : str
        File containing info about all subpeaks (gzip-compressed, tab-separated)
    fdr_target : float
        fdr target
    read_counts_threshold : int
        Min number of reads for a peak to be considered
    small_peak_threshold : int
        Peaks shorter than this threshold will be evaluated by Poisson instead of ZIP, by default, 5
    min_mu_percent : float
        Local backgrounds smaller than this quantile of all non-trivial peak mus will be replaced.
        By default, 0.1.
    disable_ler : bool
        Disable LER. By default, False.
    Returns
    -------
    result_df : pd.DataFrame
        Window bed in dataframe. A ``TypeError`` raised while processing is
        logged, and the rows collected up to that point are returned.
    """
    per_base_cov = np.load(rc_file, allow_pickle=True)
    subpeak_bed = output_file.replace(".bed", "_subpeaks_%s.bed" % chromosome_name)
    bins = []
    all_peak_mus = []
    try:
        # pass 1: per sub peak density estimates -> subpeak_bed
        # (the output file is closed even if fitting fails half way)
        with open(subpeak_bed, "w") as spb_fh, gzip.open(subpeak_file, "rt") as peak_obj:
            for line in peak_obj:
                items = line.strip().split("\t")
                start = int(items[1])
                end = int(items[2])
                peak_region = per_base_cov[start:end]
                window_value = peak_region.sum()
                # number of covered positions; ZIP shouldn't be used when only a few positions carry reads
                n_start_sizes = np.count_nonzero(peak_region > 0)
                peak_len = end - start
                if peak_len < small_peak_threshold or window_value == 0 or n_start_sizes <= 3:
                    # trivial sub peak: plain mean density without zero inflation
                    mu_peak = window_value / peak_len
                    # NOTE(review): np.var of a scalar is always 0.0; np.var(peak_region)
                    # may have been intended. Kept as-is so results do not change.
                    var_peak = np.var(window_value)
                    pi_peak = 0
                else:
                    mu_peak, var_peak, pi_peak, _, _ = stat_tester.fit(peak_region)
                    all_peak_mus.append(mu_peak)
                summit_offset = np.argmax(peak_region)

                # columns: chrom, start, end, name, mu, var, strand, pi, summit, summit value
                spb_fh.write("%s\t%d\t%d\t%s\t%f\t%s\t%s\t%f\t%d\t%d\n" % (
                    chromosome_name, start, end, items[3], mu_peak, var_peak, strand_sign, pi_peak,
                    start + summit_offset, peak_region[summit_offset]))

        index_bed_file(subpeak_bed, logger=logger)

        # pass 2: test every sub peak against its local environment
        bed_handler = pysam.TabixFile(subpeak_bed + ".gz")
        try:
            # candidate peaks denser than this are removed from local environments (see check_window)
            peak_threshold = 1
            if len(all_peak_mus) == 0:
                logger.warning("No non-trivial peak was detected on chromosome %s" % chromosome_name)
                bkg_mu_threshold = 0
            else:
                # floor for the local background mu
                bkg_mu_threshold = np.quantile(all_peak_mus, min_mu_percent)
                logger.info("Minimum mu in local environment %f (%s)" % (bkg_mu_threshold, chromosome_name))
            global_cache = dict()
            with gzip.open(subpeak_bed + ".gz", "rt") as peak_obj:
                for peak in peak_obj:
                    candidate_peak = peak.split("\t")
                    peak_start = int(candidate_peak[1])
                    peak_end = int(candidate_peak[2])
                    peak_id = candidate_peak[3]
                    peak_mu = float(candidate_peak[4])
                    peak_var = float(candidate_peak[5])
                    peak_pi = float(candidate_peak[7])
                    peak_summit = int(candidate_peak[8])
                    peak_summit_val = int(candidate_peak[9])
                    pval, wv, mu_bg, pi_bg, lerc = check_window(coord_start=peak_start, coord_end=peak_end,
                                                                mu_peak=peak_mu,
                                                                var_peak=peak_var, pi_peak=peak_pi,
                                                                chromosome_coverage=per_base_cov,
                                                                peak_in_bg_threshold=peak_threshold,
                                                                mu_bkg_minimum=bkg_mu_threshold,
                                                                sp_bed_handler=bed_handler,
                                                                chromosome_name=chromosome_name,
                                                                fdr_target=fdr_target,
                                                                cache=global_cache, disable_ler=disable_ler)
                    if wv > read_counts_threshold:
                        bins.append(
                            (chromosome_name, peak_start, peak_end, peak_id, pval, wv, mu_bg, pi_bg, peak_mu,
                             peak_pi, peak_summit, peak_summit_val, lerc[0], lerc[1], lerc[2]))
        finally:
            bed_handler.close()
    except TypeError as e:
        logger.error(str(chromosome_name) + "\t" + str(subpeak_file))
        logger.error(e)
    result_df = pd.DataFrame(bins, columns=("chromosome", "start", "end", "name", "pval", "reads",
                                            "mu_0", "pi_0", "mu_1", "pi_1", "summit", "summit_val",
                                            "ler_1", "ler_2", "ler_3"))
    result_df["strand"] = strand_sign
    result_df = result_df.loc[:, ("chromosome", "start", "end", "name", "pval", "strand", "reads",
                                  "mu_0", "pi_0", "mu_1", "pi_1", "summit", "summit_val", "ler_1", "ler_2", "ler_3")]
    return result_df


def peaks_single_strand(per_base_cov, output_file, shared_peak_definitions, strand_sign,
                        **kwargs):
    """
    Calling peaks on one strand

    Steps:

    1. ``check_window_chromosome`` evaluates the sub peaks of every chromosome
       that has peak definitions. This runs in a process pool when
       ``thread_n`` != 1.
    2. Unless ``disable_small`` is set, p-values of short peaks (length <=
       ``small_peak_threshold``) are recomputed with a Poisson model. Its rate
       is the ``top_peak_threshold`` quantile of the densities (``mu_1``) of
       longer peaks, floored at 1. When ``disable_small`` is set, short peaks
       are dropped.
    3. Short and long peaks are filtered separately with
       ``independent_filtering``.
    4. The result is written to ``output_file``, then compressed and indexed.

    Parameters
    ----------
    per_base_cov : dict
        Per base cov for available chromosomes (chromosome -> saved numpy coverage file)
    output_file : str
        Path of output files
    shared_peak_definitions : dict
        Dictionary containing all subpeaks per chromosome (None for chromosomes without signal)
    strand_sign : str
        Strand sign for the data
    **kwargs :
        Options read here: ``fdr_target``, ``read_counts_threshold``,
        ``small_peak_threshold``, ``min_mu_percent``, ``disable_ler``,
        ``thread_n``, ``output_diagnostics``, ``top_peak_threshold``, and
        optionally ``disable_small`` (default False). All of them are also
        forwarded to ``independent_filtering``.

    Returns
    -------
    result_df : str
        Path to a compressed and indexed bed file
    """
    fn = os.path.splitext(output_file)[0]

    args = []
    for chrom, pbc_npy in per_base_cov.items():
        if shared_peak_definitions[chrom] is None:  # bypass chromosomes without signals
            continue
        sub_peaks_name = output_file.replace(".bed", "_subpeaks_%s.bed" % chrom)
        merged_name = output_file.replace(".bed", "_%s_merged_windows.bed" % chrom)
        args.append((pbc_npy, output_file, strand_sign, chrom, shared_peak_definitions[chrom], kwargs["fdr_target"],
                     kwargs["read_counts_threshold"], kwargs["small_peak_threshold"], kwargs["min_mu_percent"],
                     kwargs["disable_ler"]))
        # intermediate files, registered in the module-level housekeeping list
        housekeeping_files.extend((merged_name, sub_peaks_name + ".gz", sub_peaks_name + ".gz.tbi"))

    if kwargs["thread_n"] == 1:
        # sequential path, handy for debugging
        sub_dfs = [check_window_chromosome(*arg_i) for arg_i in args]
    else:
        with Pool(kwargs["thread_n"]) as pool:
            sub_dfs = pool.starmap(check_window_chromosome, args)

    sub_dfs = [sdf for sdf in sub_dfs if sdf is not None]
    log_assert(len(sub_dfs) > 0, "No signal found across all chromosomes!", logger)
    # every per-chromosome frame carries its own 0..n-1 index; renumber so that
    # boolean masks and row-wise results below align on unique labels
    tmp_df = pd.concat(sub_dfs, ignore_index=True)

    if kwargs["output_diagnostics"]:
        tmp_df.to_csv(output_file.replace(".bed", "_debug.csv"), index=False)
    big_peaks_probe = tmp_df.end - tmp_df.start > kwargs["small_peak_threshold"]
    small_peaks_probe = tmp_df.end - tmp_df.start <= kwargs["small_peak_threshold"]
    lamb_global = tmp_df.loc[big_peaks_probe, "mu_1"].quantile(kwargs["top_peak_threshold"])
    # also replaces NaN (no long peaks at all) by 1
    lamb_global = lamb_global if lamb_global >= 1 else 1

    if kwargs["output_diagnostics"]:
        import matplotlib.pyplot as plt
        tx = np.arange(0, 1, 0.01)
        qs = [tmp_df.loc[big_peaks_probe, "mu_1"].quantile(x) for x in tx]
        plt.plot(tx, qs)
        plt.xlabel("Quantile")
        plt.ylabel("Peak density")
        plt.tight_layout()
        plt.savefig(fn + "_small_peak_threshold.pdf", transparent=True, bbox_inches="tight")
        plt.close()
    logger.info("Lambda for small peaks: %f" % lamb_global)
    inflated_small_peaks = np.sum(small_peaks_probe)

    # many short peaks share (reads, expected counts); cache their Poisson tail probabilities
    small_pois_dict = dict()

    def cached_pois(x):
        """Return the stored p-value for long peaks, a (cached) global Poisson p-value for short ones"""
        peak_len = x["end"] - x["start"]
        if peak_len > kwargs["small_peak_threshold"]:
            return x["pval"]
        expected_counts = lamb_global * peak_len
        key = (x["reads"], expected_counts)
        if key not in small_pois_dict:
            small_pois_dict[key] = poisson.sf(x["reads"], expected_counts)
        return small_pois_dict[key]

    is_disable_small = kwargs.get("disable_small", False)
    if is_disable_small:
        tmp_df = tmp_df.loc[big_peaks_probe, :]
        tmp_df_sm = None
    else:
        tmp_df["pval"] = tmp_df.apply(cached_pois, axis=1)

        corrected_small_peaks = np.sum(np.logical_and(tmp_df["pval"] < kwargs["fdr_target"],
                                                      small_peaks_probe))
        logger.info("Significant small peaks after correction: %d (%d)" % (corrected_small_peaks, inflated_small_peaks))
        tmp_df_sm = independent_filtering(tmp_df.loc[small_peaks_probe, :], output_to=fn + "_idpf_sm.pdf",
                                          logger=logger, **kwargs)
    if kwargs["output_diagnostics"]:
        pval_dist(tmp_df.loc[tmp_df["end"] - tmp_df["start"] > kwargs["small_peak_threshold"], "pval"],
                  logger=logger,
                  output_diagnostics=kwargs["output_diagnostics"],
                  output_to=fn + "_broad_pval_hist.pdf")
        pval_dist(tmp_df.loc[tmp_df["end"] - tmp_df["start"] <= kwargs["small_peak_threshold"], "pval"],
                  logger=logger,
                  output_diagnostics=kwargs["output_diagnostics"],
                  output_to=fn + "_narrow_peaks_pval_hist.pdf")
        pval_dist(tmp_df["pval"],
                  logger=logger,
                  output_diagnostics=kwargs["output_diagnostics"],
                  output_to=fn + "_pval_hist.pdf")

    # stratified independent filtering
    tmp_df_bg = independent_filtering(tmp_df.loc[tmp_df["end"] - tmp_df["start"] > kwargs["small_peak_threshold"],
                                      :], output_to=fn + "_idpf_bg.pdf", logger=logger, **kwargs)
    if tmp_df_sm is not None:
        result_df = pd.concat([tmp_df_sm, tmp_df_bg])
    else:
        result_df = tmp_df_bg

    result_df = result_df.loc[:, COMMON_HEADER]
    result_df.sort_values(by=['chromosome', 'start'], inplace=True)
    result_df.to_csv(output_file, sep="\t", index=False, header=False)
    index_bed_file(output_file, logger=logger)
    return output_file + ".gz"


def merge_opposite_peaks(sig_peak_bed, peak_candidate_bed, divergent_output_bed, bidirectional_output_bed,
                         singleton_bed, fdr_target, stringent_only=False, **kwargs):
    """
    Merge peaks on the opposite strand and generate divergent peak pairs

    For every significant peak in ``sig_peak_bed``, candidate peaks are fetched from the tabix-indexed
    ``peak_candidate_bed``. The search region is the peak itself extended by ``close_threshold`` bp on its
    upstream side: towards lower coordinates for "+" peaks and towards higher coordinates for "-" peaks.
    Among the candidates that pass the filters below, the one with the smallest p-value (8th column) becomes
    the partner.

    * A peak without a partner is copied unchanged to ``singleton_bed``.
    * A paired element longer than ``div_size_min`` is written to ``bidirectional_output_bed``.
      It is also written to ``divergent_output_bed`` when (forward-strand summit - reverse-strand summit)
      is at least ``summit_dist_min``. Shorter pairs are copied unchanged to ``singleton_bed``.

    Parameters
    ----------
    sig_peak_bed : str
        Path to bed file which contains significant peaks (the last two columns are read as the
        summit position and the summit value)
    peak_candidate_bed : str
        Path to the tabix-indexed bed file which contains all candidate peaks on the opposite strand
    divergent_output_bed : str
        Path to output which stores divergent peaks
    bidirectional_output_bed : str
        Path to output which stores bidirectional peaks (divergent / convergent)
    singleton_bed : str
        Path to output which stores significant peaks which failed to pair
    fdr_target : float
        FDR target. It is also used to label each pair as Stringent(qval), Stringent(pval) or Relaxed.
    stringent_only : bool
        Set it to True if you only want to keep significant pairs (both peaks needs to be significant).
        Opposite-strand candidates whose p-value exceeds ``fdr_target`` are then ignored.

    **kwargs :
        close_threshold : int
            Distance threshold for two peaks (on opposite strands) to be merged, by default 300
        min_len_opposite_peaks : int
            Minimum length requirement for peaks on the opposite strand to be paired,
            set it to 0 (default) to disable this requirement
        div_size_min : int
            A pair is only reported when the merged element is longer than this, by default 0
        summit_dist_min : int
            Minimum (forward summit - reverse summit) for a pair to be reported as divergent, by default 0

    Returns
    -------
    None
        Results are written to ``divergent_output_bed``, ``bidirectional_output_bed`` and ``singleton_bed``.
    """
    close_threshold = kwargs.get("close_threshold", 300)
    min_len_opposite_peaks = kwargs.get("min_len_opposite_peaks", 0)
    div_size_min = kwargs.get("div_size_min", 0)
    summit_dist_min = kwargs.get("summit_dist_min", 0)

    tbx = pysam.TabixFile(peak_candidate_bed)
    try:
        # context managers guarantee every output is flushed and closed even if parsing fails midway
        with open(sig_peak_bed, "r") as fh, \
                open(divergent_output_bed, "w") as div_fh, \
                open(bidirectional_output_bed, "w") as bid_fh, \
                open(singleton_bed, "w") as sfp_fh:  # singletons failed to pair
            for line in fh:
                items = line.strip().split("\t")
                chromosome = items[0]
                strand = items[5]
                start = int(items[1])
                end = int(items[2])
                current_summit = int(items[-2])
                current_summit_val = int(items[-1])
                # allow overlapping: search the peak itself plus close_threshold bp on its upstream side
                if strand == "+":
                    query_start = start - close_threshold
                    query_end = end
                else:
                    query_start = start
                    query_end = end + close_threshold
                query_start = max(query_start, 0)

                # each candidate: (start, end, qval, pval, reads, summit, summit_val)
                candidates = []
                partner = None
                opposite_sum = 0
                # since windows on each strand have been merged,
                # the fetch below is expected to return only a few records
                try:
                    for hit in tbx.fetch(chromosome, query_start, query_end, parser=pysam.asTuple()):
                        hit_start = int(hit[1])
                        hit_end = int(hit[2])
                        hit_qval = float(hit[4])
                        hit_reads = float(hit[6])  # in case the read counts had been normed
                        hit_summit = int(hit[-2])
                        hit_summit_val = int(hit[-1])
                        # filter peaks on the other strand which are shorter than a threshold
                        if min_len_opposite_peaks > 0 and hit_end - hit_start < min_len_opposite_peaks:
                            continue
                        hit_pval = float(hit[7])
                        if stringent_only and hit_pval > fdr_target:
                            continue
                        candidates.append((hit_start, hit_end, hit_qval, hit_pval, hit_reads,
                                           hit_summit, hit_summit_val))
                    if candidates:
                        best = int(np.argmin([c[3] for c in candidates]))
                        partner = candidates[best]
                        # NOTE(review): reads of every candidate fetched up to and including the partner are
                        # summed, not only the partner's own reads -- confirm this is intended
                        opposite_sum = sum(c[4] for c in candidates[:best + 1])
                except ValueError as err:
                    # raised e.g. when the contig is absent from the tabix index or a field cannot be parsed
                    logger.warning("No peak candidate among %s:%d-%d\n%s" % (items[0], query_start, query_end, err))

                if partner is None:
                    sfp_fh.write(line)
                    continue

                (opposite_start, opposite_end, opposite_qval, opposite_pval, _,
                 opposite_summit, opposite_summit_val) = partner

                # report summits (and their values) as forward-strand / reverse-strand regardless of
                # which strand the current peak is on
                if strand == "+":
                    fwd_summit, fwd_summit_val = current_summit, current_summit_val
                    rev_summit, rev_summit_val = opposite_summit, opposite_summit_val
                else:
                    fwd_summit, fwd_summit_val = opposite_summit, opposite_summit_val
                    rev_summit, rev_summit_val = current_summit, current_summit_val

                coords = (start, end, opposite_start, opposite_end)
                tre_start = min(coords)
                tre_end = max(coords)
                if opposite_qval < fdr_target:
                    pairing_confidence = "Stringent(qval)"
                elif opposite_pval < fdr_target:
                    pairing_confidence = "Stringent(pval)"
                else:
                    pairing_confidence = "Relaxed"

                if tre_end - tre_start > div_size_min:
                    candidate_values = (chromosome, str(tre_start), str(tre_end), ".",
                                        items[4], strand, str(float(items[6]) + opposite_sum), items[1],
                                        items[2], str(opposite_start), str(opposite_end), str(fwd_summit),
                                        str(fwd_summit_val), str(rev_summit), str(rev_summit_val),
                                        pairing_confidence + "\n")
                    record = "\t".join(candidate_values)
                    bid_fh.write(record)
                    if fwd_summit - rev_summit >= summit_dist_min:
                        div_fh.write(record)
                else:
                    sfp_fh.write(line)
    finally:
        tbx.close()


def housekeeping(save_to):
    """
    Delete intermediate files

    Cleanup runs in three steps:

    1. Remove every path registered in the module-level ``housekeeping_files`` list that still exists.
    2. Call ``pybedtools.cleanup(remove_all=True)``.
    3. Delete any leftover ``pybedtools.*.tmp`` files found in ``save_to``.

    Cleanup is best effort. A failure in any single step is logged as a warning and does not prevent
    the remaining files from being removed.

    Parameters
    ----------
    save_to : str
        Directory scanned for leftover ``pybedtools.*.tmp`` files

    Returns
    -------
    None
    """
    import glob

    def _remove_quietly(path):
        # one undeletable file must not stop the cleanup of all the others
        try:
            if os.path.exists(path):
                os.remove(path)
        except Exception as err:
            logger.warning(str(err))

    for hf in housekeeping_files:
        _remove_quietly(hf)

    try:
        pybedtools.cleanup(remove_all=True)
    except Exception as err:
        logger.warning(str(err))

    for f in glob.glob(os.path.join(save_to, "pybedtools.*.tmp")):
        _remove_quietly(f)


def show_parameter_info(input_bam, output_dir, output_prefix, thread_n, model, iqr_strategy, stringent_pairs_only,
                        **kwargs):
    """
    Show parameters

    Logs the command line, then every run parameter. Input and control file lists are logged as
    space-separated paths. All remaining keyword arguments are logged as ``key: value`` pairs.

    Parameters
    ----------
    input_bam : list of str or None
        Paths to the input BAM files, or None when bigwig files are used
    output_dir : str
        Path to the output dir
    output_prefix : str
        Output prefix
    thread_n : int
        Number of threads
    model : str or None
        Name of the model
    iqr_strategy : str
        IQR strategy
    stringent_pairs_only : bool
        Whether Relaxed pairs should be included in bidirectional output
    kwargs
        Remaining options. ``bam_parser``, ``bw_pl``, ``bw_mn``, ``ct_bam``, ``ct_bw_pl`` and ``ct_bw_mn``
        are reported in the input section; everything else is listed afterwards.

    Returns
    -------
    None
    """

    def _paths(files):
        # file options are lists (argparse nargs="*") but may be missing altogether;
        # log "None" instead of crashing so that input validation can report the real problem
        return "None" if files is None else " ".join(files)

    args = kwargs.copy()
    bam_parser = args.pop("bam_parser", None)
    input_pl_bw = args.pop("bw_pl", None)
    input_mn_bw = args.pop("bw_mn", None)
    control_bam = args.pop("ct_bam", None)
    control_pl_bw = args.pop("ct_bw_pl", None)
    control_mn_bw = args.pop("ct_bw_mn", None)
    logger.info("Command")
    logger.info(" ".join(sys.argv))
    logger.info("Parameters")
    if input_bam is not None:
        logger.info("input_bam(s): {input} ({parser})".format(input=_paths(input_bam), parser=bam_parser))
    else:
        logger.info("input_pl_bw(s): {input}".format(input=_paths(input_pl_bw)))
        logger.info("input_mn_bw(s): {input}".format(input=_paths(input_mn_bw)))
    if control_bam is not None:
        logger.info("ct_bam(s): {input} ({parser})".format(input=_paths(control_bam), parser=bam_parser))
    elif control_pl_bw is not None and control_mn_bw is not None:
        logger.info("ct_bw_pl(s): {input}".format(input=_paths(control_pl_bw)))
        logger.info("ct_bw_mn(s): {input}".format(input=_paths(control_mn_bw)))
    logger.info("output_dir: %s" % output_dir)
    logger.info("output_prefix: %s" % output_prefix)
    logger.info("thread_n: %s" % thread_n)
    logger.info("model: %s" % model)
    logger.info("IQR strategy: %s" % iqr_strategy)
    logger.info("Stringent pairing strategy: %s" % stringent_pairs_only)
    for k, v in args.items():
        logger.info("%s: %s" % (k, v))


def unified_element_definition(coverage_dict, chromosome_of_interest, strand_sign, output_file, peak_rel_height=1,
                               window_size=100, step_size=100, ce_donor=1.0, ce_receptor=0.1, ce_trigger=3):
    """
    Unified element boundary definition
    If multiple replicates are present, then this function will merge all signal tracks together

    The procedure has four steps:

    1. Merge the per-replicate coverage of ``chromosome_of_interest`` with ``merge_covs``.
    2. Scan the merged coverage with a sliding window, keeping only windows that contain reads.
    3. Merge those windows into contiguous regions and split each region into sub peaks with ``cut_peaks``.
    4. Write every sub peak, together with its summit coordinate (position of maximum merged coverage),
       to a bed file. That file is then processed by ``index_bed_file``.

    Parameters
    ----------
    coverage_dict : list of dicts
        List of coverage dicts for each rep
    chromosome_of_interest : str
        Name of the chromosome/contig to be working on
    strand_sign : str
        Sign of strand to be added in the bed file, + or -
    output_file : str
        Path of the per-strand bed output (e.g. ``<prefix>_pl.bed``). The sub-peak file name is derived
        from it by replacing the trailing ".bed" with "_subpeaks_<chromosome>.bed".
    peak_rel_height : float
        Relative height for a spot to be considered as a sub peak/seed, by default, 1.
    window_size : int
        Window size, by default, 100
    step_size : int
        Step size, by default, 100 (non-overlap)
    ce_donor : float
        Donor tolerance in best score segments
    ce_receptor : float
        Receptor tolerance in best score segments
    ce_trigger : int
        Trigger for receptor tolerance checking

    Returns
    -------
    subpeak_bed : str or None
        file name for the subpeak bed file (with the ".gz" suffix added after indexing),
        or None if there's no peak
    """
    bed_suffix = ".bed"
    subpeak_suffix = "_subpeaks_%s.bed" % chromosome_of_interest
    if output_file.endswith(bed_suffix):
        # only rewrite the extension; str.replace would also rewrite ".bed" inside directory or prefix names
        subpeak_bed = output_file[:-len(bed_suffix)] + subpeak_suffix
    else:
        subpeak_bed = output_file.replace(bed_suffix, subpeak_suffix)
    merged_coverage = merge_covs(coverage_dict, chromosome_of_interest)

    # skip windows without reads
    bins = [(chromosome_of_interest, coord[0], coord[1], window)
            for window, coord in sliding_window(merged_coverage, window_size=window_size, step_size=step_size)
            if window > 0]

    logger.info("Before merging, there are %d windows on %s" % (len(bins), chromosome_of_interest))
    tmp_df = pd.DataFrame(bins, columns=("chromosome", "start", "end", "reads"))
    tmp_df["name"] = "."
    tmp_df["strand"] = strand_sign
    tmp_df = tmp_df.loc[:, ("chromosome", "start", "end", "name", "reads", "strand")]
    if tmp_df.shape[0] == 0:  # no hit
        return None
    # merge windows in case peaks are split into different windows
    bed_obj = BedTool(tmp_df.to_csv(sep="\t", index=False, header=False), from_string=True)
    bed_obj = bed_obj.merge(c=(4, 5, 6), o=("distinct", "sum", "distinct"))
    merged_windows = bed_obj.to_dataframe(names=("Chromosome", "Start", "End", "Name", "Reads", "Strand"))
    logger.info("After merging, there are %d windows on %s" % (merged_windows.shape[0], chromosome_of_interest))

    index = 1
    with open(subpeak_bed, "w") as spb_fh:
        for row in bed_obj:
            sub_peaks = cut_peaks(merged_coverage[row.start:row.end],
                                  peak_rel_height=peak_rel_height,
                                  donor_tolerance=ce_donor,
                                  receptor_tolerance=ce_receptor,
                                  ce_trigger=ce_trigger,
                                  )
            for sp in sub_peaks:
                # sub-peak coordinates are relative to the merged region
                start = sp[0] + row.start
                end = sp[1] + row.start
                peak_region = merged_coverage[start:end]
                summit_coord = start + np.argmax(peak_region)
                spb_fh.write("%s\t%d\t%d\t%s-%d\t%d\n" % (chromosome_of_interest, start, end, chromosome_of_interest,
                                                          index, summit_coord))
                index += 1

    if index > 1:
        index_bed_file(subpeak_bed, logger=logger)
        return subpeak_bed + ".gz"
    return None


def inferring_elements_from_other_reps(prefix, n_samples):
    """
    Infer bidirectional/divergent elements by borrowing signals from reps

    The bidirectional (and, separately, divergent) calls of all replicates are pooled into one merged set.
    For every replicate, a pooled element is added to that replicate's calls when both of these hold:

    * it does not overlap any of the replicate's own calls;
    * it does overlap one of the replicate's unidirectional peaks.

    The augmented, sorted calls are saved back to ``<prefix>_<rep>_bidirectional_peaks.bed`` and
    ``<prefix>_<rep>_divergent_peaks.bed``.

    Parameters
    ----------
    prefix : str
        Prefix for outputs (including path)
    n_samples : int
        number of samples/reps (replicates are numbered from 1)

    Returns
    -------
    None
    """
    bid_calls = []
    div_calls = []
    uni_calls = []
    for rep_no in range(1, n_samples + 1):
        rep_prefix = prefix + "_%d" % rep_no
        bid_path = rep_prefix + "_bidirectional_peaks.bed"
        div_path = rep_prefix + "_divergent_peaks.bed"
        uni_path = rep_prefix + "_unidirectional_peaks.bed"
        log_assert(os.path.exists(bid_path), "Cannot locate bidirectional output %s" % bid_path, logger)
        log_assert(os.path.exists(div_path), "Cannot locate divergent output %s" % div_path, logger)
        log_assert(os.path.exists(uni_path), "Cannot locate unidirectional output %s" % uni_path, logger)
        bid_calls.append(BedTool(bid_path))
        div_calls.append(BedTool(div_path))
        uni_calls.append(BedTool(uni_path))

    def _pool(calls):
        # merge the calls of every replicate into one set of non-redundant intervals
        return BedTool.cat(*calls, c=(4, 5, 6), o=("distinct", "distinct", "distinct"))

    pooled = (
        ("bidirectional", bid_calls, _pool(bid_calls)),
        ("divergent", div_calls, _pool(div_calls)),
    )

    for label, per_rep_calls, pool in pooled:
        for rep_no, rep_calls in enumerate(per_rep_calls, start=1):
            # pooled elements missed by this replicate but supported by one of its unidirectional peaks
            rescued = pool.intersect(rep_calls, v=True).intersect(uni_calls[rep_no - 1], u=True)
            combined = BedTool.cat(rep_calls, rescued, postmerge=False)
            combined.sort().saveas(prefix + "_%d_%s_peaks.bed" % (rep_no, label))


def _keep_strongest_tss(merged_df, tss_column, value_column):
    """
    Collapse comma-separated TSS candidates to the one with the highest signal (in place)

    After plus- and minus-strand calls are merged with ``BedTool.cat``, ``tss_column`` and ``value_column``
    may hold comma-separated lists, one entry per merged call. For such rows only the position with the
    largest value is kept. Cells are converted with ``str`` first because pandas parses a column as integers
    when none of its rows contains a comma.

    Parameters
    ----------
    merged_df : pd.DataFrame
        Data frame produced by ``BedTool.to_dataframe``; modified in place
    tss_column : str
        Column holding TSS coordinates
    value_column : str
        Column holding the signal at each TSS, aligned entry-by-entry with ``tss_column``
    """
    for nr, row in merged_df.iterrows():
        positions = str(row[tss_column])
        if "," not in positions:
            continue
        values = list(map(int, str(row[value_column]).split(",")))
        merged_df.loc[nr, tss_column] = positions.split(",")[np.argmax(values)]


def peak_calling(input_bam, output_dir=".", output_prefix="pints", filters=(),
                 thread_n=1, model="ZIP", iqr_strategy="bgIQR", stringent_only=False, **kwargs):
    """
    Peak calling wrapper

    The pipeline runs these steps:

    1. Load per-base signal for every replicate, from BAM files or from pairs of bigwig files.
    2. Define shared element boundaries per chromosome and strand.
    3. Optionally normalize the signal against control/input libraries.
    4. Call peaks on each strand with ``peaks_single_strand``.
    5. Pair significant peaks with peaks on the opposite strand.
    6. Write ``<prefix>_<rep>_bidirectional_peaks.bed``, ``_divergent_peaks.bed`` and
       ``_unidirectional_peaks.bed`` for every replicate.

    Parameters
    ----------
    input_bam : list of str or None
        Paths to the input bam files (one per replicate), or None when bigwig files
        (``bw_pl`` / ``bw_mn`` in kwargs) are used instead
    output_dir : str
        Path to write output
    output_prefix : str
        Prefix for all outputs
    filters : list or tuple
        List of keywords to filter chromosomes
    thread_n : int
        Max number of sub processes that can be created.
    model : str
        Statistical model to use, can be Poisson, ZIP, NB or ZINB
    iqr_strategy : str
        IQR strategy to use, can be bgIQR (based on background densities) or pkIQR (based on peak densities)
    stringent_only : bool
        Set it to True if you only want to keep significant pairs (both peaks needs to be significant)
    kwargs :
        Remaining command-line options, e.g. ``bw_pl``, ``bw_mn``, ``ct_bam``, ``ct_bw_pl``, ``ct_bw_mn``,
        ``bam_parser``, ``fdr_target``, ``output_diagnostics`` and ``borrow_info_reps``.

    Returns
    -------
    None
    """
    global housekeeping_files
    global stat_tester
    global iqr_obj
    log_assert(os.path.exists(output_dir) and os.path.isdir(output_dir), "Cannot write to {0}".format(output_dir),
               logger)
    pybedtools.set_tempdir(output_dir)
    runtime_check()
    logger.info("PINTS version: {0}".format(__version__))
    show_parameter_info(input_bam, output_dir, output_prefix, thread_n, model, iqr_strategy, stringent_only, **kwargs)

    prefix = os.path.join(output_dir, output_prefix)

    log_assert(model in ("Poisson", "ZIP", "NB", "ZINB"), "Unsupported model", logger)
    if model == "ZIP":
        stat_tester = ZIP()
    elif model == "Poisson":
        stat_tester = Poisson()
    elif model == "NB":
        stat_tester = NegativeBinomial()
    elif model == "ZINB":
        logger.warning("You're using Zero-inflated Negative Binomial model (not suggested)")
        stat_tester = ZINB()
    else:
        logger.error("The model you specified {model} is not supported".format(model=model))

    # fail fast: logging alone would let the run continue with ``iqr_obj`` unset or stale
    log_assert(iqr_strategy in ("bgIQR", "pkIQR"),
               "The IQR strategy you specified {iqr_strategy} is not supported".format(iqr_strategy=iqr_strategy),
               logger)
    if iqr_strategy == "bgIQR":
        iqr_obj = bgIQR()
    elif iqr_strategy == "pkIQR":
        iqr_obj = pkIQR()

    log_assert(input_bam is not None or (kwargs["bw_pl"] is not None and kwargs["bw_mn"] is not None),
               "You must provide PINTS a BAM file (--bam-file) or two bigwig files (--bw-pl and --bw-mn) "
               "for the experiment", logger)

    if kwargs["bw_pl"] is not None and kwargs["bw_mn"] is not None:
        log_assert(len(kwargs["bw_pl"]) == len(kwargs["bw_mn"]),
                   "If you want to use bigwig files as input, make sure you provide both bws for forward and reverse strand",
                   logger)

    input_coverage_pl = []
    input_coverage_mn = []
    chromosome_coverage_pl = []
    chromosome_coverage_mn = []
    rcs = []  # read counts for experiment
    ircs = []  # read counts for input/control
    chromosomes = set()

    if input_bam is not None:
        log_assert(kwargs["bam_parser"] is not None, "Please specify which type of experiment this data "
                                                     "was generated from with --exp-type", logger)
        for i, bf in enumerate(input_bam):
            log_assert(os.path.exists(bf), "Cannot find input bam file %s" % bf, logger)
            plc, mnc, rc = get_read_signal(input_bam=bf,
                                           loc_prime=kwargs["bam_parser"],
                                           reverse_complement=kwargs["seq_rc"],
                                           output_dir=output_dir,
                                           output_prefix=output_prefix + "_%d" % i,
                                           filters=filters,
                                           **kwargs)
            chromosomes = chromosomes.union(set(plc.keys()))
            chromosomes = chromosomes.union(set(mnc.keys()))
            chromosome_coverage_pl.append(plc)
            chromosome_coverage_mn.append(mnc)
            rcs.append(rc)
            housekeeping_files.extend(plc.values())
            housekeeping_files.extend(mnc.values())
    else:
        log_assert(len(kwargs["bw_pl"]) == len(kwargs["bw_mn"]),
                   "Must provide the same amount of bigwig files for both strands", logger)

        for i, bw_pl in enumerate(kwargs["bw_pl"]):
            log_assert(os.path.exists(bw_pl), "Cannot find bigwig file %s" % bw_pl, logger)
            log_assert(os.path.exists(kwargs["bw_mn"][i]), "Cannot find bigwig file %s" % kwargs["bw_mn"][i], logger)
            plc, mnc, rc = get_coverage_bw(bw_pl=bw_pl, bw_mn=kwargs["bw_mn"][i],
                                           chromosome_startswith=kwargs["chromosome_startswith"],
                                           output_dir=output_dir,
                                           output_prefix=output_prefix + "_%d" % i)
            chromosomes = chromosomes.union(set(plc.keys()))
            chromosomes = chromosomes.union(set(mnc.keys()))
            chromosome_coverage_pl.append(plc)
            chromosome_coverage_mn.append(mnc)
            rcs.append(rc)
            housekeeping_files.extend(plc.values())
            housekeeping_files.extend(mnc.values())

    subpeak_pl_beds = dict()
    subpeak_mn_beds = dict()
    if kwargs["gene_annotation"] is not None:
        log_assert(os.path.exists(kwargs["gene_annotation"]), "Cannot find gene annotation file", logger)
        if kwargs["highlight_chromosome"] in chromosome_coverage_pl[0]:
            kwargs["donor_tolerance"] = cut_peaks_dry_run(kwargs["gene_annotation"], chromosome_coverage_pl,
                                                          chromosome_coverage_mn,
                                                          highlight_chromosome=kwargs["highlight_chromosome"],
                                                          output_diagnostics=kwargs["output_diagnostics"],
                                                          save_to=prefix + "_alpha.pdf")
            logger.info("Override the default for --donor-tolerance with {0}".format(kwargs["donor_tolerance"]))

    # element boundaries are defined once on the signal of all replicates and shared by every replicate
    for chromosome in chromosomes:
        for target_dict, chrom_coverage, sign, strand_short in zip((subpeak_pl_beds, subpeak_mn_beds),
                                                                   (chromosome_coverage_pl, chromosome_coverage_mn),
                                                                   ("+", "-"), ("pl", "mn")):
            target_dict[chromosome] = unified_element_definition(chrom_coverage, chromosome, sign,
                                                                 prefix + "_{0}.bed".format(strand_short),
                                                                 peak_rel_height=kwargs.get("peak_rel_height", 1),
                                                                 window_size=kwargs.get("window_size", 100),
                                                                 step_size=kwargs.get("step_size", 100),
                                                                 ce_donor=kwargs.get("donor_tolerance", 1),
                                                                 ce_receptor=kwargs.get("receptor_tolerance", 0.1),
                                                                 ce_trigger=kwargs.get("ce_trigger", 3))

            if target_dict[chromosome] is not None:
                housekeeping_files.append(target_dict[chromosome])
                housekeeping_files.append(target_dict[chromosome] + ".tbi")

    if kwargs["ct_bam"] is not None:
        logger.info("Loading control sample")
        for i, bf in enumerate(kwargs["ct_bam"]):
            log_assert(os.path.exists(bf), "Cannot find control bam file %s" % bf, logger)
            iplc, imnc, irc = get_read_signal(input_bam=bf,
                                              loc_prime=kwargs["bam_parser"],
                                              reverse_complement=kwargs["seq_rc"],
                                              output_dir=output_dir,
                                              output_prefix="ct_" + output_prefix + "_%d" % i,
                                              filters=filters,
                                              **kwargs)
            input_coverage_pl.append(iplc)
            input_coverage_mn.append(imnc)
            ircs.append(irc)
            housekeeping_files.extend(iplc.values())
            housekeeping_files.extend(imnc.values())
        logger.info("Control sample loaded")
    elif kwargs["ct_bw_pl"] is not None and kwargs["ct_bw_mn"] is not None:
        logger.info("Loading control sample")
        for i, bw_pl in enumerate(kwargs["ct_bw_pl"]):
            log_assert(os.path.exists(bw_pl), "Cannot find control bw file %s" % bw_pl, logger)
            log_assert(os.path.exists(kwargs["ct_bw_mn"][i]),
                       "Cannot find control bw file %s" % kwargs["ct_bw_mn"][i], logger)
            logger.info(bw_pl)
            logger.info(kwargs["ct_bw_mn"][i])
            iplc, imnc, irc = get_coverage_bw(bw_pl=bw_pl,
                                              bw_mn=kwargs["ct_bw_mn"][i],
                                              chromosome_startswith=kwargs["chromosome_startswith"],
                                              output_dir=output_dir,
                                              output_prefix="ct_" + output_prefix + "_%d" % i)

            input_coverage_pl.append(iplc)
            input_coverage_mn.append(imnc)
            ircs.append(irc)
            housekeeping_files.extend(iplc.values())
            housekeeping_files.extend(imnc.values())
        logger.info("Control sample loaded")

    if len(ircs) > 0:
        if len(ircs) == 1 and len(rcs) > 1:
            logger.info("Only one input sample is provided, it will be shared among all treatment libraries")
            for _ in range(len(rcs) - len(ircs)):
                ircs.append(ircs[0])
                input_coverage_pl.append(input_coverage_pl[0])
                input_coverage_mn.append(input_coverage_mn[0])
        # zip() below would otherwise silently leave some experiment libraries un-normalized
        log_assert(len(ircs) == len(rcs),
                   "The number of control libraries ({0}) must be 1 or equal to the number of "
                   "experiment libraries ({1})".format(len(ircs), len(rcs)), logger)
        for i, (rc, irc) in enumerate(zip(rcs, ircs)):
            log_assert(irc > 0, "Control library {0} contains no reads".format(i + 1), logger)
            scale_factor = rc / irc
            logger.info("Adjusting signals based-on input/control (scale factor: %.4f)" % scale_factor)
            plc, mnc = normalize_using_input(chromosome_coverage_pl[i],
                                             chromosome_coverage_mn[i],
                                             input_coverage_pl[i],
                                             input_coverage_mn[i],
                                             scale_factor=scale_factor,
                                             output_dir=output_dir,
                                             output_prefix=output_prefix + "_inputnorm_%d" % i,
                                             logger=logger)
            chromosome_coverage_pl[i] = plc
            chromosome_coverage_mn[i] = mnc
            logger.info("Signals adjusted.")

            housekeeping_files.extend(plc.values())
            housekeeping_files.extend(mnc.values())

    # peak calling (IQR)
    for rep, pl_cov_dict in enumerate(chromosome_coverage_pl):
        logger.info("Working on sample %d" % (rep + 1))
        sample_prefix = prefix + "_%d" % (rep + 1)
        df_dict = dict()
        for cov_dict, label, spb, strand_sign in zip(
                (pl_cov_dict, chromosome_coverage_mn[rep]),
                ("pl", "mn"), (subpeak_pl_beds, subpeak_mn_beds), ("+", "-")):
            peaks_bed = peaks_single_strand(per_base_cov=cov_dict,
                                            output_file=sample_prefix + "_{0}.bed".format(label),
                                            shared_peak_definitions=spb,
                                            strand_sign=strand_sign,
                                            thread_n=thread_n,
                                            **kwargs)

            peak_df = pd.read_csv(peaks_bed, sep="\t", header=None, names=COMMON_HEADER)
            peak_df = peak_df.loc[peak_df["end"] - peak_df["start"] < kwargs["window_size_threshold"], :]
            df_dict[label] = peak_df

            sig_bins = peak_df.loc[peak_df["padj"] < kwargs["fdr_target"], :]
            with open(sample_prefix + "_sig_%s.bed" % label, "w") as f:
                sig_bins.to_csv(f, sep="\t", index=False, header=False)

        for label, anti_label in zip(("pl", "mn"), ("mn", "pl")):
            if not os.path.exists(sample_prefix + "_sig_%s.bed" % label):
                continue
            merge_opposite_peaks(sample_prefix + "_sig_%s.bed" % label, sample_prefix + "_%s.bed.gz" % anti_label,
                                 divergent_output_bed=sample_prefix + "_sig_%s_divergent_peaks.bed" % label,
                                 bidirectional_output_bed=sample_prefix + "_sig_%s_bidirectional_peaks.bed" % label,
                                 singleton_bed=sample_prefix + "_sig_%s_singletons_peaks.bed" % label,
                                 stringent_only=stringent_only,
                                 **kwargs)

        for directionality in ("bidirectional", "divergent", "singletons"):
            exp_pl_file = sample_prefix + "_sig_pl_%s_peaks.bed" % directionality
            exp_mn_file = sample_prefix + "_sig_mn_%s_peaks.bed" % directionality

            if os.path.exists(exp_pl_file) and os.path.exists(exp_mn_file):
                if directionality != "singletons":
                    # columns 12-16 are: fwd summit, fwd summit value, rev summit, rev summit value, confidence
                    pri_merged_file = BedTool.cat(*[BedTool(exp_pl_file),
                                                    BedTool(exp_mn_file)],
                                                  c=(12, 13, 14, 15, 16),
                                                  o=("collapse", "collapse", "collapse", "collapse", "distinct"))
                    pri_merged_df = pri_merged_file.to_dataframe(names=("chrom", "start", "end", "tss_fwd",
                                                                        "tss_fwd_vals", "tss_rev", "tss_rev_vals",
                                                                        "confidence"))
                    _keep_strongest_tss(pri_merged_df, "tss_fwd", "tss_fwd_vals")
                    _keep_strongest_tss(pri_merged_df, "tss_rev", "tss_rev_vals")
                    BedTool.from_dataframe(
                        pri_merged_df.loc[:, ("chrom", "start", "end", "confidence", "tss_fwd", "tss_rev")]).saveas(
                        sample_prefix + "_%s_peaks.bed" % directionality)
                else:
                    BedTool.cat(*[BedTool(exp_pl_file),
                                  BedTool(exp_mn_file)],
                                postmerge=False).sort().saveas(sample_prefix + "_unidirectional_peaks.bed")

        if kwargs.get("output_diagnostics", False):
            peak_bed_to_gtf(pl_df=df_dict["pl"], mn_df=df_dict["mn"],
                            save_to=sample_prefix + "_peaks.gtf", version=__version__)

        housekeeping_files.extend(sample_prefix + suffix for suffix in (
            "_pl.bed.gz", "_pl.bed.gz.tbi", "_mn.bed.gz", "_mn.bed.gz.tbi",
            "_sig_pl.bed", "_sig_mn.bed",
            "_sig_pl_singletons_peaks.bed", "_sig_mn_singletons_peaks.bed",
            "_sig_pl_divergent_peaks.bed", "_sig_mn_divergent_peaks.bed",
            "_sig_pl_bidirectional_peaks.bed", "_sig_mn_bidirectional_peaks.bed"))

        logger.info("Finished on sample %d" % (rep + 1))
        logger.info("Divergent peaks were saved to %s" % (sample_prefix + "_divergent_peaks.bed"))
        logger.info("Bidirectional peaks were saved to %s" % (sample_prefix + "_bidirectional_peaks.bed"))
        logger.info(
            "Significant peaks which failed to pair were saved to %s" % (sample_prefix + "_unidirectional_peaks.bed"))
    logger.info("Logs were saved to %s" % (DEFAULT_PREFIX + ".log"))
    is_borrow_info_from_reps = kwargs.pop("borrow_info_reps", False)
    if is_borrow_info_from_reps and len(chromosome_coverage_pl) > 1:
        logger.info("Enhanced support for biological replicates is enabled.")
        inferring_elements_from_other_reps(prefix=prefix, n_samples=len(chromosome_coverage_pl))
    # delete intermediate files
    housekeeping(output_dir)


def build_arg_parser(version):
    """Build the command-line parser for the PINTS peak caller.

    Parameters
    ----------
    version : str
        Version string reported by ``-v/--version``.

    Returns
    -------
    argparse.ArgumentParser
        Parser whose ``dest`` names are the attributes forwarded to ``peak_calling``
        by the script entry point.
    """
    parser = argparse.ArgumentParser(description="Peak Identifier for Nascent Transcripts Sequencing")
    group = parser.add_argument_group("Input/Output")
    group.add_argument("--bam-file", action="store", dest="bam_file", nargs="*",
                       type=str, required=False,
                       help="input bam file, if you want to use bigwig files, please use --bw-pl and --bw-mn")
    # NOTE(review): --save-to and --file-prefix are required=True, so their defaults can never apply.
    group.add_argument("--save-to", action="store", dest="save_to",
                       type=str, required=True, default=".",
                       help="save peaks to this path (a folder), by default, current folder")
    group.add_argument("--file-prefix", action="store", dest="file_prefix",
                       type=str, required=True, default=str(os.getpid()),
                       help="prefix to all intermediate files")
    group.add_argument("--bw-pl", action="store", dest="bw_pl", nargs="*",
                       type=str, required=False,
                       help="Bigwig for plus strand. If you want to use bigwig instead of BAM, "
                            "please set bam_file to bigwig")
    group.add_argument("--bw-mn", action="store", dest="bw_mn", nargs="*",
                       type=str, required=False,
                       help="Bigwig for minus strand. If you want to use bigwig instead of BAM, "
                            "please set bam_file to bigwig")
    group.add_argument("--ct-bw-pl", action="store", dest="ct_bw_pl", nargs="*",
                       type=str, required=False,
                       help="Bigwig for control/input (plus strand). If you want to use BAM instead of bigwig, "
                            "please use --ct-bam")
    group.add_argument("--ct-bw-mn", action="store", dest="ct_bw_mn", nargs="*",
                       type=str, required=False,
                       help="Bigwig for input/control (minus strand). If you want to use BAM instead of bigwig, "
                            "please use --ct-bam")
    group.add_argument("--ct-bam", action="store", dest="ct_bam", nargs="*",
                       type=str, required=False,
                       help="Bam file for input/control. If you want to use bigwig instead of BAM, "
                            "please use --ct-bw-pl and --ct-bw-mn")
    group.add_argument("--exp-type", action="store", default="CoPRO", dest="bam_parser",
                       choices=("CoPRO", "GROcap", "PROcap", "CAGE", "NETCAGE", "csRNAseq", "PROseq", "GROseq",
                                "R_5", "R_3", "R1_5", "R1_3", "R2_5", "R2_3"),
                       help="Type of experiment, acceptable values are: "
                            "CoPRO/GROcap/PROcap/CAGE/NETCAGE/csRNAseq/PROseq/GROseq, or if you "
                            "know the position of RNA ends which you're interested on the reads, you can specify "
                            "R_5, R_3, R1_5, R1_3, R2_5 or R2_3")
    group.add_argument("--reverse-complement", action="store_true", dest="seq_reverse_complement",
                       required=False, default=False,
                       help="Set this switch if reads in this library represent the reverse complement of nascent "
                            "RNAs, like PROseq")
    group.add_argument("-f", "--filters", action="store", type=str, nargs="*", default=[],
                       help="reads from chromosomes whose names contain any matches in filters will be ignored")

    group = parser.add_argument_group("Filtering")
    group.add_argument("--adjust-method", action="store", dest="adjust_method",
                       choices=("fdr_bh", "bonferroni", "fdr_tsbh", "fdr_tsbky"),
                       type=str, required=False, default="fdr_bh", help="method for calculating adjusted p-vals")
    group.add_argument("--fdr-target", action="store", dest="fdr_target",
                       type=float, required=False, default=0.1, help="FDR target for multiple testing")
    group.add_argument("--close-threshold", action="store", dest="close_threshold",
                       type=int, required=False, default=300,
                       help="Distance threshold for two peaks (on opposite strands) to be merged")
    group.add_argument("--stringent-pairs-only", action="store_true", dest="stringent_pairs_only",
                       required=False, default=False,
                       help="Only consider elements as bidirectional when both of the two peaks are significant "
                            "according to their p-values")
    group.add_argument("--min-lengths-opposite-peaks", dest="min_len_opposite_peaks",
                       required=False, default=0, type=int,
                       help="Minimum length requirement for peaks on the opposite strand to be paired, "
                            "set it to 0 to lift this requirement")
    group.add_argument("--mapq-threshold", action="store", dest="mapq_threshold",
                       type=int, required=False, default=30, help="Minimum mapping quality")
    group.add_argument("--read-counts-threshold", action="store", dest="read_counts_threshold",
                       type=int, required=False, default=0,
                       help="Threshold for a window to be considered as having read")
    group.add_argument("--small-peak-threshold", action="store", dest="small_peak_threshold",
                       type=int, required=False, default=5,
                       help="Threshold for small peaks, peaks with width smaller than this value will be required "
                            "to run extra test")
    # type=float: the default (0.005) is fractional; type=int rejected every user-supplied value like "0.01".
    group.add_argument("--ind-filtering-granularity", action="store", dest="ind_filter_granularity",
                       type=float, required=False, default=0.005,
                       help="Granularity for independent filtering")
    group.add_argument("--window-size", action="store", dest="window_size",
                       type=int, required=False, default=100, help="size for sliding windows")
    group.add_argument("--max-window-size", action="store", dest="window_size_threshold",
                       type=int, required=False, default=2000, help="max size of divergent windows")
    group.add_argument("--step-size", action="store", dest="step_size",
                       type=int, required=False, default=100, help="step size for sliding windows")

    group = parser.add_argument_group("Edge trimming")
    group.add_argument("--annotation-gtf", action="store", dest="annotation_gtf", type=str, required=False,
                       help="Gene annotation file (gtf) format for learning the threshold for edge trimming. "
                            "If this is specified, other related parameters like --donor-tolerance will be ignored.")
    group.add_argument("--focused-chrom", action="store", dest="focused_chrom", default="chr1", type=str,
                       required=False,
                       help="If --annotation-gtf is specified, you use this parameter to change which chromosome the "
                            "tool should learn the values from.")
    group.add_argument("--donor-tolerance", action="store", dest="donor_tolerance",
                       type=float, required=False, default=1.0, help="Donor tolerance in best score segments")
    group.add_argument("--receptor-tolerance", action="store", dest="receptor_tolerance",
                       type=float, required=False, default=0.1, help="Receptor tolerance in best score segments")
    group.add_argument("--ce-trigger", action="store", dest="ce_trigger",
                       type=int, required=False, default=3, help="Trigger for receptor tolerance checking")

    group = parser.add_argument_group("Peak properties")
    # NOTE(review): this help text duplicates --div-size-min; confirm the intended meaning of this threshold.
    group.add_argument("--top-peak-threshold", action="store", dest="top_peak_threshold",
                       type=float, required=False, default=0.75,
                       help="Min size for a divergent peak")
    group.add_argument("--min-mu-percent", action="store", dest="min_mu_percent",
                       type=float, required=False, default=0.1,
                       help="Local backgrounds smaller than this percentile among all peaks will be replaced.")
    # NOTE(review): --peak-distance and --peak-width are parsed but not forwarded to peak_calling by the
    # entry point below; confirm whether they should be.
    group.add_argument("--peak-distance", action="store", dest="peak_distance",
                       type=int, required=False, default=None,
                       help="Required minimal horizontal distance (>= 1) in samples between neighbouring peaks.")
    group.add_argument("--peak-width", action="store", dest="peak_width",
                       type=int, required=False, default=None,
                       help="Required width of peaks in samples.")
    # type=float: a relative height is naturally fractional (e.g. 0.5); the default stays 1.
    group.add_argument("--peak-rel-height", action="store", dest="peak_rel_height",
                       type=float, required=False, default=1,
                       help="Used for calculation of the peaks width, thus it is only used if width is given.")
    group.add_argument("--div-size-min", action="store", dest="div_size_min",
                       type=int, required=False, default=0,
                       help="Min size for a divergent peak")
    group.add_argument("--summit-dist-min", action="store", dest="summit_dist_min",
                       type=int, required=False, default=0,
                       help="Min dist between two summit")

    group = parser.add_argument_group("Testing")
    group.add_argument("--model", action="store", type=str, required=False, default="ZIP",
                       choices=("ZIP", "NB", "Poisson", "ZINB"),
                       help="Statistical model for testing, can be Poisson, ZIP, NB or ZINB")
    group.add_argument("--IQR-strategy", action="store", type=str, required=False,
                       dest="iqr_strategy", default="bgIQR", choices=("bgIQR", "pkIQR"),
                       help="IQR strategy, can be bgIQR (more robust) or pkIQR (more efficient)")
    group.add_argument("--disable-ler", action="store_true", required=False, default=False,
                       help="Disable Local Environment Refinement")
    group.add_argument("--disable-small", action="store_true", required=False, default=False,
                       help="Disable small peak correction")

    group = parser.add_argument_group("Other")
    group.add_argument("--chromosome-start-with", action="store", dest="chromosome_startswith",
                       type=str, required=False, default="chr",
                       help="Only keep reads mapped to chromosomes with this prefix")
    group.add_argument("--dont-output-chrom-size", action="store_false", dest="output_chrom_size",
                       required=False, default=True,
                       help="Don't write chromosome dict to local folder (not recommended)")
    group.add_argument("--debug", action="store_true", dest="output_diagnostics",
                       required=False, default=False,
                       help="Save diagnostics (independent filtering and pval dist) to local folder")
    group.add_argument("--borrow-info-reps", action="store_true", dest="borrow_info_reps",
                       required=False, default=False,
                       help="Borrow information from reps to refine calling of divergent elements")
    group.add_argument("--thread", action="store", dest="thread_n",
                       type=int, required=False, default=1,
                       help="Max number of threads PINTS can create")
    parser.add_argument("-v", "--version", action="version", version=version)
    return parser


def make_run_prefix(now=None):
    """Return the per-run prefix used to name the log file.

    Parameters
    ----------
    now : datetime.datetime, optional
        Timestamp to embed. When omitted, ``datetime.datetime.now()`` is used.

    Returns
    -------
    str
        ``"peakcalling_YYYY_mm_dd_HH_MM_SS"``.
    """
    if now is None:
        now = datetime.datetime.now()
    # %m is the month. The previous "%M" is minutes, so the month was lost and minutes appeared twice.
    return "peakcalling_" + now.strftime("%Y_%m_%d_%H_%M_%S")


if __name__ == "__main__":
    parser = build_arg_parser(__version__)
    args = parser.parse_args()

    warnings.filterwarnings("ignore", message="numpy.dtype size changed")
    warnings.filterwarnings("ignore", message="numpy.ufunc size changed")
    # Module-level global: the peak-calling code references DEFAULT_PREFIX when reporting where logs were saved.
    DEFAULT_PREFIX = make_run_prefix()
    handler = logging.FileHandler(os.path.join(args.save_to, "{0}.log".format(DEFAULT_PREFIX)))
    formatter = logging.Formatter("%(name)s - %(asctime)s - %(levelname)s: %(message)s")
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    # redirect exception message to log
    sys.excepthook = handle_exception

    peak_calling(args.bam_file, args.save_to, args.file_prefix, close_threshold=args.close_threshold,
                 window_size=args.window_size, step_size=args.step_size, fdr_target=args.fdr_target,
                 adjust_method=args.adjust_method, mapq_threshold=args.mapq_threshold,
                 read_counts_threshold=args.read_counts_threshold, small_peak_threshold=args.small_peak_threshold,
                 chromosome_startswith=args.chromosome_startswith, window_size_threshold=args.window_size_threshold,
                 peak_rel_height=args.peak_rel_height, gene_annotation=args.annotation_gtf,
                 output_chrom_size=args.output_chrom_size, output_diagnostics=args.output_diagnostics,
                 ind_filter_granularity=args.ind_filter_granularity, thread_n=args.thread_n,
                 div_size_min=args.div_size_min, summit_dist_min=args.summit_dist_min,
                 top_peak_threshold=args.top_peak_threshold, bw_pl=args.bw_pl, bw_mn=args.bw_mn,
                 donor_tolerance=args.donor_tolerance, receptor_tolerance=args.receptor_tolerance,
                 ce_trigger=args.ce_trigger, bam_parser=args.bam_parser, seq_rc=args.seq_reverse_complement,
                 filters=args.filters, min_mu_percent=args.min_mu_percent, ct_bam=args.ct_bam, ct_bw_pl=args.ct_bw_pl,
                 ct_bw_mn=args.ct_bw_mn, model=args.model, iqr_strategy=args.iqr_strategy,
                 highlight_chromosome=args.focused_chrom, disable_ler=args.disable_ler,
                 stringent_only=args.stringent_pairs_only, min_len_opposite_peaks=args.min_len_opposite_peaks,
                 disable_small=args.disable_small, borrow_info_reps=args.borrow_info_reps)
