#!/usr/bin/env python
import deeptools.heatmapper as heatmapper
import numpy as np
import argparse
import sys

def parse_arguments():
    """
    Build the command-line parser for subsetMatrix.

    The returned parser has one sub-parser per command (info, subset,
    filterStrand, rbind and cbind). The name of the chosen command is stored
    in the ``command`` attribute of the parsed namespace.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="""
This tool subsets or reorders a file produced by computeMatrix by group and/or sample.

detailed help:

  subsetMatrix info -h

or

  subsetMatrix subset -h

or

  subsetMatrix filterStrand -h

or

  subsetMatrix rbind -h

or
  subsetMatrix cbind -h

""",
        # The example spells the option as it is defined (--groups), matching
        # the usage text of the subset command below.
        epilog='example usages:\n'
               'subsetMatrix subset -m input.mat.gz -o output.mat.gz --groups "group 1" "group 2" --samples "sample 3" "sample 10"\n\n'
               ' \n\n')

    subparsers = parser.add_subparsers(
        title='Commands',
        dest='command',
        metavar='')

    # info: only needs the input matrix
    subparsers.add_parser(
        'info',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[infoArgs()],
        help="Print group and sample information",
        usage='An example usage is:\n  subsetMatrix info -m input.mat.gz\n\n')

    # subset: input matrix plus output file and group/sample selections
    subparsers.add_parser(
        'subset',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[infoArgs(), subsetArgs()],
        help="Actually subset the matrix. The group and sample orders are honored, so one can also reorder files.",
        usage='An example usage is:\n  subsetMatrix subset -m '
        'input.mat.gz -o output.mat.gz --groups "group 1" "group 2" '
        '--samples "sample 3" "sample 10"\n\n')

    # filterStrand: input matrix plus output file and the strand to keep
    subparsers.add_parser(
        'filterStrand',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[infoArgs(), filterStrandArgs()],
        help="Filter entries by strand.",
        usage='Example usage:\n  subsetMatrix filterStrand -m '
        'input.mat.gz -o output.mat.gz --strand +\n\n')

    # rbind: several input matrices, one output file
    subparsers.add_parser(
        'rbind',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bindArgs()],
        help="merge multiple matrices by concatenating them head to tail. This assumes that the same samples are present in each in the same order.",
        usage='Example usage:\n  subsetMatrix rbind -m '
        'input1.mat.gz input2.mat.gz -o output.mat.gz\n\n')

    # cbind: several input matrices, one output file
    subparsers.add_parser(
        'cbind',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        parents=[bindArgs()],
        help="merge multiple matrices by concatenating them left to right. No assumptions are made about the row order, though all region names must be present in all matrices.",
        usage='Example usage:\n  subsetMatrix cbind -m '
        'input1.mat.gz input2.mat.gz -o output.mat.gz\n\n')

    return parser


def bindArgs():
    """
    Parent parser shared by the rbind and cbind commands.

    Declares the (one or more) input matrix files and the output file name,
    both mandatory.
    """
    parent = argparse.ArgumentParser(add_help=False)
    mandatory = parent.add_argument_group('Required arguments')

    # One or more inputs, collected into a list
    mandatory.add_argument(
        '--matrixFile', '-m',
        required=True,
        nargs='+',
        help='Matrix files from the computeMatrix tool.')

    mandatory.add_argument(
        '--outFileName', '-o',
        required=True,
        help='Output file name')

    return parent


def infoArgs():
    """
    Parent parser declaring the single, mandatory input matrix file.
    """
    parent = argparse.ArgumentParser(add_help=False)
    mandatory = parent.add_argument_group('Required arguments')

    mandatory.add_argument(
        '--matrixFile', '-m',
        required=True,
        help='Matrix file from the computeMatrix tool.')

    return parent


def subsetArgs():
    """
    Parent parser with the options specific to the subset command.

    The output file name is mandatory; --groups and --samples each take one
    or more labels and are left as None when not given.
    """
    parent = argparse.ArgumentParser(add_help=False)

    mandatory = parent.add_argument_group('Required arguments')
    mandatory.add_argument(
        '--outFileName', '-o',
        required=True,
        help='Output file name')

    selection = parent.add_argument_group('Optional arguments')
    selection.add_argument(
        '--groups',
        help="Groups to include. If none are specified then all will be included.",
        nargs='+')
    selection.add_argument(
        '--samples',
        help="Samples to include. If none are specified then all will be included.",
        nargs='+')

    return parent


def filterStrandArgs():
    """
    Parent parser with the options specific to the filterStrand command.

    Both the output file name and the strand (one of '+', '-' or '.') are
    mandatory.
    """
    parent = argparse.ArgumentParser(add_help=False)
    mandatory = parent.add_argument_group('Required arguments')

    mandatory.add_argument(
        '--outFileName', '-o',
        required=True,
        help='Output file name')

    mandatory.add_argument(
        '--strand', '-s',
        required=True,
        choices=['+', '-', '.'],
        help='Strand')

    return parent


def printInfo(matrix):
    """
    Write the group labels and then the sample labels to standard output,
    each under its own heading and indented by a tab.
    """
    sections = (("Groups:", "group_labels"), ("Samples:", "sample_labels"))
    for heading, attribute in sections:
        print(heading)
        # Labels are looked up only once their heading has been printed
        for label in getattr(matrix.matrix, attribute):
            print("\t{0}".format(label))


def getGroupBounds(args, matrix):
    """
    Translate the requested group labels into row indices.

    Returns a pair: the rows to keep (in the order the groups were requested)
    and the cumulative group boundaries of the resulting matrix. Without any
    requested groups all rows and the existing boundaries are returned.
    Exits with an error message if a label is not a known group.
    """
    boundaries = matrix.parameters['group_boundaries']
    if args.groups is None:
        nRows = matrix.matrix.matrix.shape[0]
        return range(0, nRows), np.array(boundaries)

    labels = matrix.matrix.group_labels
    rows = []
    sizes = [0]  # leading 0 so that the cumulative sum starts at row 0
    for label in args.groups:
        if label not in labels:
            sys.exit("Error: '{0}' is not a valid group\n".format(label))
        pos = labels.index(label)
        first, last = boundaries[pos], boundaries[pos + 1]
        rows += range(first, last)
        sizes.append(last - first)
    return rows, np.cumsum(sizes)


def getSampleBounds(args, matrix):
    """
    Translate the requested sample labels into column indices.

    Returns the columns to keep, in the order the samples were requested, or
    every column when no samples were requested. Exits with an error message
    if a label is not a known sample.
    """
    edges = matrix.parameters['sample_boundaries']
    if args.samples is None:
        nCols = matrix.matrix.matrix.shape[1]
        return np.arange(0, nCols)

    labels = matrix.matrix.sample_labels
    columns = []
    for name in args.samples:
        if name not in labels:
            sys.exit("Error: '{0}' is not a valid sample\n".format(name))
        pos = labels.index(name)
        columns += range(edges[pos], edges[pos + 1])
    return columns


def subsetRegions(matrix, bounds):
    """
    Collect the regions at the row indices in bounds, in that order.

    Each region is returned as a list of the form
    [chrom, [(start, end), (start, end), ...], name, 0, strand, score],
    with the comma-separated start/end strings converted to integers.
    """
    selected = []
    for row in bounds:
        region = matrix.matrix.regions[row]
        blockStarts = [int(pos) for pos in region["start"].split(",")]
        blockEnds = [int(pos) for pos in region["end"].split(",")]
        blocks = list(zip(blockStarts, blockEnds))
        selected.append([region["chrom"], blocks, region["name"], 0,
                         region["strand"], region["score"]])
    return selected


def filterHeatmap(hm, args):
    """
    Keep only the regions whose strand equals args.strand.

    hm.matrix.matrix, hm.matrix.regions and hm.matrix.group_boundaries are
    all replaced on hm; nothing is returned. The kept regions are stored as
    lists of the form
    [chrom, [(start, end), (start, end), ...], name, 0, strand, score],
    with the start/end values left as the strings found in the input.
    """
    # An explicit boolean dtype keeps the mask usable as an index even when
    # there are no regions at all (an empty list would otherwise become a
    # float array, which cannot be used to index the matrix).
    keep = np.array([region["strand"] == args.strand for region in hm.matrix.regions],
                    dtype=bool)

    # Get the new bounds: every group shrinks to the number of rows it keeps.
    # The counts are converted to plain Python ints so that the boundaries
    # stay an ordinary list of ints, like the one the subset command stores
    # via tolist(), rather than a list of numpy integers.
    oldBounds = hm.matrix.group_boundaries
    bounds = [0]
    for idx in range(1, len(oldBounds)):
        kept = int(np.sum(keep[oldBounds[idx - 1]:oldBounds[idx]]))
        bounds.append(bounds[-1] + kept)
    hm.matrix.group_boundaries = bounds

    # subset the matrix
    hm.matrix.matrix = hm.matrix.matrix[keep, :]

    # subset the regions
    regions = []
    for idx, reg in enumerate(hm.matrix.regions):
        if keep[idx]:
            blocks = list(zip(reg["start"].split(","), reg["end"].split(",")))
            regions.append([reg["chrom"], blocks, reg["name"], 0, reg["strand"], reg["score"]])
    hm.matrix.regions = regions


def rbindMatrices(hm, args):
    """
    Stack the matrices listed in args.matrixFile on top of each other.

    The first file is read into hm; the rows and regions of every further
    file are appended to it and the last group boundary is moved accordingly.

    This only supports a single group at this point, and it is assumed that
    the same samples are present in every file in the exact same order.
    """
    other = heatmapper.heatmapper()
    hm.read_matrix_file(args.matrixFile[0])

    for fname in args.matrixFile[1:]:
        other.read_matrix_file(fname)
        # Grow the (single) group by the number of rows being appended
        hm.matrix.group_boundaries[-1] += other.matrix.group_boundaries[-1]
        hm.matrix.regions.extend(other.matrix.regions)
        hm.matrix.matrix = np.concatenate([hm.matrix.matrix, other.matrix.matrix], axis=0)

    # Rewrite each region as
    # [chrom, [(start, end), (start, end), ...], name, 0, strand, score]
    hm.matrix.regions = [
        [reg["chrom"],
         list(zip(reg["start"].split(","), reg["end"].split(","))),
         reg["name"], 0, reg["strand"], reg["score"]]
        for reg in hm.matrix.regions]


def cbindMatrices(hm, args):
    """
    Join the matrices listed in args.matrixFile left to right.

    The first file is read into hm and defines the row order. The rows of
    every further file are rearranged, by region name, to that order before
    their columns are appended; the sample labels and sample boundaries in
    hm.parameters are extended accordingly.

    All regions (identified by their names) MUST be present in each matrix,
    otherwise the program exits with an error message.
    """
    hm2 = heatmapper.heatmapper()

    hm.read_matrix_file(args.matrixFile[0])
    # Region names of the first matrix, in row order
    refNames = [reg['name'] for reg in hm.matrix.regions]
    refNameSet = set(refNames)

    # Iterate through the other matrices
    for fname in args.matrixFile[1:]:
        hm2.read_matrix_file(fname)

        # Region name -> row in the matrix being appended.
        # NOTE(review): names are assumed to be unique within a file; with
        # duplicates the last row carrying a name is used - confirm.
        rowOf = dict((reg['name'], row) for row, reg in enumerate(hm2.matrix.regions))
        if set(rowOf) != refNameSet:
            sys.exit("Error: '{0}' does not contain the same regions as '{1}'\n".format(
                fname, args.matrixFile[0]))

        # Add the sample labels
        # NOTE(review): only hm.parameters is updated; presumably the lists on
        # hm.matrix are the same objects - confirm against heatmapper.
        hm.parameters['sample_labels'].extend(hm2.parameters['sample_labels'])
        # Add the sample boundaries, shifted by the current number of columns
        offset = hm.parameters['sample_boundaries'][-1]
        hm.parameters['sample_boundaries'].extend(
            [x + offset for x in hm2.parameters['sample_boundaries']][1:])

        # For every row of hm, pick the row of hm2 holding the same region.
        # (Indexing hm2 with "row of hm for each row of hm2" would apply the
        # inverse permutation and misalign the rows.)
        order = np.array([rowOf[name] for name in refNames], dtype=int)
        hm.matrix.matrix = np.hstack((hm.matrix.matrix, hm2.matrix.matrix[order, :]))

    # Rewrite each region as
    # [chrom, [(start, end), (start, end), ...], name, 0, strand, score]
    regions = []
    for reg in hm.matrix.regions:
        blocks = list(zip(reg["start"].split(","), reg["end"].split(",")))
        regions.append([reg["chrom"], blocks, reg["name"], 0, reg["strand"], reg["score"]])
    hm.matrix.regions = regions


def main():
    """
    Command-line entry point: parse the arguments and run the chosen command.

    info prints the group and sample labels; subset, filterStrand, rbind and
    cbind modify the heatmapper object and save it to args.outFileName.
    """
    parser = parse_arguments()
    args = parser.parse_args()

    # On Python 3 argparse does not require a sub-command, so a bare
    # invocation yields command=None and a namespace without matrixFile.
    # Report that as a usage error instead of failing on the attribute below.
    if args.command is None:
        parser.error("a command is required (info, subset, filterStrand, rbind or cbind)")

    hm = heatmapper.heatmapper()
    # rbind and cbind receive a list of files and read them themselves
    if not isinstance(args.matrixFile, list):
        hm.read_matrix_file(args.matrixFile)

    if args.command == 'info':
        printInfo(hm)
    elif args.command == 'subset':
        sIdx = getSampleBounds(args, hm)
        gIdx, gBounds = getGroupBounds(args, hm)

        # groups
        hm.matrix.regions = subsetRegions(hm, gIdx)
        # matrix
        hm.matrix.matrix = hm.matrix.matrix[gIdx, :]
        hm.matrix.matrix = hm.matrix.matrix[:, sIdx]
        # boundaries
        samples = args.samples
        if samples is None:
            samples = hm.matrix.sample_labels
        # NOTE(review): keeping the first len(samples) + 1 boundaries is only
        # right if every sample spans the same number of columns - confirm.
        hm.matrix.sample_boundaries = hm.matrix.sample_boundaries[0:len(samples) + 1]
        hm.matrix.group_boundaries = gBounds.tolist()
        # labels
        hm.matrix.sample_labels = samples
        groups = args.groups
        if groups is None:
            groups = hm.matrix.group_labels
        hm.matrix.group_labels = groups
        # save
        hm.save_matrix(args.outFileName)
    elif args.command == 'filterStrand':
        filterHeatmap(hm, args)
        hm.save_matrix(args.outFileName)
    elif args.command == 'rbind':
        rbindMatrices(hm, args)
        hm.save_matrix(args.outFileName)
    elif args.command == 'cbind':
        cbindMatrices(hm, args)
        hm.save_matrix(args.outFileName)
    else:
        sys.exit("Unknown command {0}!\n".format(args.command))


# Run the command-line interface only when this file is executed as a script.
if __name__ == "__main__":
    main()
