#!/usr/bin/env python

from __future__ import print_function
import argparse
import os.path
import textwrap
import sys

import logging
# Configure the root logger at import time: INFO level and above, each record
# formatted as "<timestamp> <level name> <message>".
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')


class TADtool(object):

    def __init__(self):
        """Parse the subcommand name from the command line and dispatch to it.

        Only ``sys.argv[1:2]`` is consumed here; the method with the same name
        as the command is then called and parses the remaining arguments
        itself.

        If the command does not name a public, callable attribute of this
        object, 'Unrecognized command' and the top-level help are printed and
        the process exits with status 1.
        """
        usage = '''\
                tadtool <command> [options]

                Commands:
                    plot                Main interactive TADtool plotting window
                    tads                Call TADs with pre-defined parameters
                    subset              Reduce a matrix to a smaller region.

                Run tadtool <command> -h for help on a specific command.
                '''
        parser = argparse.ArgumentParser(
            description="Assistant to find cutoffs in TAD calling algorithms.",
            usage=textwrap.dedent(usage)
        )

        parser.add_argument('command', help='Subcommand to run')

        # parse_args defaults to [1:] for args, but you need to
        # exclude the rest of the args too, or validation will fail
        args = parser.parse_args(sys.argv[1:2])

        # Private and dunder attributes (e.g. "__init__", "__class__") exist on
        # the instance but are not commands: dispatching to them would recurse
        # or crash, so only public callables are accepted.
        handler = None
        if not args.command.startswith('_'):
            handler = getattr(self, args.command, None)

        if not callable(handler):
            print('Unrecognized command')
            parser.print_help()
            # sys.exit instead of the site-provided exit(), which is not
            # guaranteed to be defined (e.g. when running with "python -S")
            sys.exit(1)

        # use dispatch pattern to invoke method with same name
        handler()

    ##########################################################################
    #
    #                               Subcommands
    #
    ##########################################################################

    def plot(self):
        """Run the ``plot`` subcommand: show the interactive TADtool window.

        Parses ``sys.argv[2:]``, loads the regions and the Hi-C matrix,
        determines the window sizes, obtains the index data (either computed
        with the selected algorithm or loaded from the file given with ``-d``)
        and finally displays the plot with a blocking ``show`` call.

        Window sizes (``-w``) are interpreted by argument count:
        none -> 100 log-spaced defaults; one -> a file name (``.npy`` or
        whitespace-delimited text); three -> ``start stop step`` of a range;
        more than three -> the sizes themselves.

        Raises:
            ValueError: if the number of regions differs from the number of
                matrix rows, the matrix is not square, exactly two window size
                values are given, the algorithm name is unknown, or a supplied
                data matrix does not match the window sizes / matrix bins.
        """
        parser = argparse.ArgumentParser(
            prog="tadtool " + sys.argv[1],
            description='Main interactive TADtool plotting window'
        )

        parser.add_argument(
            'matrix',
            help='''Square Hi-C Matrix as tab-delimited or .npy file (created with numpy.save) 
                    or sparse matrix format (each line: 
                    <row region index> <column region index> <matrix value>)'''
        )

        parser.add_argument(
            'regions',
            help='''BED file (no header) with regions corresponding to the number of rows in the provided matrix.
                    Fourth column, if present, denotes name field, which is used as an identifier in sparse 
                    matrix notation.'''
        )

        parser.add_argument(
            'plotting_region',
            help='''
                Region of the Hi-C matrix to display in plot. Format: <chromosome>:<start>-<end>, e.g.
                chr12:31000000-33000000
            '''
        )

        parser.add_argument(
            '-w', '--window-sizes', dest='window_sizes',
            nargs='+',
            help='''
                Window sizes in base pairs used for TAD calculation. You can pass
                (1) a filename with whitespace-delimited window sizes,
                (2) three integers denoting start, stop, and step size to generate a range of window sizes, or
                (3) more than three integers to define window sizes directly.
                If left at default, window sizes will be logarithmically spaced between 10**4 and 10**6, or 10**6.5
                for the insulation and directionality index, respectively.
            '''
        )

        parser.add_argument(
            '-a', '--algorithm', dest='algorithm',
            default='insulation',
            help='''TAD-calling algorithm. Options: insulation, ninsulation, directionality. Default: insulation.'''
        )

        parser.add_argument(
            '-m', '--max-distance', dest='max_dist',
            type=int,
            help='''Maximum distance in base-pairs away from the diagonal to be shown in Hi-C plot.
                    Defaults to half the plotting window.'''
        )

        parser.add_argument(
            '-n', '--normalisation-window', dest='normalisation_window',
            type=int,
            help='''Normalisation window in number of regions. Only affects ninsulation algorithm. If not specified,
                    window will be the whole chromosome.'''
        )

        parser.add_argument(
            '-d', '--data', dest='data',
            help='''
                Matrix with index data. Rows correspond to window sizes, columns to Hi-C matrix bins.
                If provided, suppresses inbuilt index calculation.
            '''
        )

        args = parser.parse_args(sys.argv[2:])
        regions_file = os.path.expanduser(args.regions)
        matrix_file = os.path.expanduser(args.matrix)
        data_file = None if args.data is None else os.path.expanduser(args.data)
        plotting_region = args.plotting_region
        max_dist = args.max_dist

        # project and numeric imports happen only once argument parsing succeeded
        import tadtool.tad as tad
        import numpy as np
        from functools import partial

        pr = tad.GenomicRegion.from_string(plotting_region)
        if max_dist is None:
            # default: half the width of the plotting region
            max_dist = int((pr.end - pr.start)/2)

        logging.info("Loading regions...")
        regions, ix_converter = tad.load_regions(regions_file)

        # called before the matrix is loaded; the return value is not needed here
        logging.info("Checking plotting region in matrix...")
        tad.sub_regions(regions, plotting_region)

        logging.info("Loading matrix...")
        matrix = tad.load_matrix(matrix_file, len(regions), ix_converter=ix_converter)
        matrix = tad.masked_matrix(matrix)

        if len(regions) != matrix.shape[0]:
            raise ValueError("Regions ({}) must be the same length as rows in matrix ({})!".format(len(regions),
                                                                                                   matrix.shape[0]))

        if matrix.shape[0] != matrix.shape[1]:
            raise ValueError("Matrix must be square!")

        logging.info("Getting window sizes...")
        n_window_args = 0 if args.window_sizes is None else len(args.window_sizes)
        if args.window_sizes is None:
            # 100 log-spaced sizes from 10**4 up to 10**6 (insulation variants)
            # or 10**6.5 (any other algorithm)
            if args.algorithm in ('insulation', 'ninsulation'):
                upper_exponent = 6
            else:
                upper_exponent = 6.5
            window_sizes = [int(i) for i in np.logspace(4, upper_exponent, 100)]
        elif n_window_args == 1:
            # a single value is a file name: try numpy binary first, then text
            ws_path = os.path.expanduser(args.window_sizes[0])
            try:
                window_sizes = np.load(ws_path)
            except (IOError, ValueError):
                # numpy reports "not an .npy file" as IOError or ValueError
                # depending on its version, so both trigger the text fallback
                with open(ws_path, 'r') as ws_file:
                    window_sizes = [int(string_size) for string_size in ws_file.read().split()]
        elif n_window_args == 3:
            start, stop, step = [int(value) for value in args.window_sizes]
            # materialise so every branch yields a concrete sequence
            window_sizes = list(range(start, stop, step))
        elif n_window_args > 3:
            window_sizes = [int(window_size) for window_size in args.window_sizes]
        else:
            # exactly two values match none of the documented forms
            raise ValueError("Wrong window size declaration, see help (-h) for details!")

        if args.data is None:
            index_functions = {
                'insulation': tad.insulation_index,
                'ninsulation': partial(tad.normalised_insulation_index,
                                       normalisation_window=args.normalisation_window),
                'directionality': tad.directionality_index,
            }
            if args.algorithm not in index_functions:
                raise ValueError("Algorithm (-a) can only be directionality, insulation or ninsulation,"
                                 " not %s!" % args.algorithm)
            tad_algorithm = index_functions[args.algorithm]

            logging.info("Calculating index...")
            # data_array also returns window sizes, which replace the requested ones
            data, window_sizes = tad.data_array(hic_matrix=matrix, regions=regions,
                                                tad_method=tad_algorithm, window_sizes=window_sizes)
        else:
            logging.info("Loading data file...")
            data = tad.load_matrix(data_file, square=False)
            if data.shape[0] != len(window_sizes):
                raise ValueError("Number of rows in data (%d) must equal the number of window sizes (%d)!" % (
                    data.shape[0], len(window_sizes)))
            if data.shape[1] != matrix.shape[0]:
                raise ValueError("Number of columns in data (%d) must equal matrix rows (%d)!" % (
                    data.shape[1], matrix.shape[0]))

        import tadtool.plot as tp
        logging.info("Done. Showing plot...")
        tad_plot = tp.TADtoolPlot(matrix, regions, data=data, window_sizes=window_sizes,
                                  norm='lin', max_dist=max_dist, algorithm=args.algorithm)
        fig, axes = tad_plot.plot(plotting_region)
        tp.plt.show(block=True)  # necessary so that window won't close immediately

    def tads(self):
        """Run the ``tads`` subcommand: call TADs with fixed parameters.

        Parses ``sys.argv[2:]``, loads regions and matrix, computes the index
        of the selected algorithm at the given window size and then writes
        either the per-region index values (``-v``) or the called TADs.
        Output goes to the optional ``output`` file, or to standard output if
        no file is given.

        Output lines are tab-separated; start coordinates are written as
        ``start - 1`` (presumably converting to 0-based BED starts -- confirm
        against the region loader).

        Raises:
            ValueError: if the number of regions differs from the number of
                matrix rows, or the algorithm name is unknown.
        """
        parser = argparse.ArgumentParser(
            prog="tadtool " + sys.argv[1],
            description='Call TADs with pre-defined parameters'
        )

        parser.add_argument(
            'matrix',
            help='''Square Hi-C Matrix as tab-delimited or .npy file (created with numpy.save)
                    or sparse matrix format (each line: 
                    <row region index> <column region index> <matrix value>)'''
        )

        parser.add_argument(
            'regions',
            help='''BED file (no header) with regions corresponding to the number of rows in the provided matrix.
                    Fourth column, if present, denotes name field, which is used as an identifier in sparse 
                    matrix notation.'''
        )

        parser.add_argument(
            'window_size',
            type=int,
            help='''Window size in base pairs'''
        )

        parser.add_argument(
            'cutoff',
            type=float,
            help='''Cutoff for TAD-calling algorithm at given window size.'''
        )

        parser.add_argument(
            'output',
            nargs='?',
            help='''Optional output file to save TADs.'''
        )

        parser.add_argument(
            '-a', '--algorithm', dest='algorithm',
            default='insulation',
            help='''TAD-calling algorithm. Options: insulation, ninsulation, directionality. Default: insulation.'''
        )

        parser.add_argument(
            '-n', '--normalisation-window', dest='normalisation_window',
            type=int,
            help='''Normalisation window in number of regions. Only affects ninsulation algorithm. If not specified,
                    window will be the whole chromosome.'''
        )

        parser.add_argument(
            '-v', '--write-values', dest='write_values',
            action='store_true',
            help='''Write index values to file instead of TADs.'''
        )
        parser.set_defaults(write_values=False)

        args = parser.parse_args(sys.argv[2:])
        regions_file = os.path.expanduser(args.regions)
        matrix_file = os.path.expanduser(args.matrix)
        write_values = args.write_values

        import tadtool.tad as tad
        from functools import partial

        logging.info("Loading regions...")
        regions, ix_converter = tad.load_regions(regions_file)

        logging.info("Loading matrix...")
        matrix = tad.load_matrix(matrix_file, len(regions), ix_converter=ix_converter)

        if len(regions) != matrix.shape[0]:
            raise ValueError("Regions must be the same length as rows in matrix!")

        # algorithm name -> (index function, TAD caller); both insulation
        # variants share the same caller
        algorithms = {
            'insulation': (tad.insulation_index, tad.call_tads_insulation_index),
            'ninsulation': (partial(tad.normalised_insulation_index,
                                    normalisation_window=args.normalisation_window),
                            tad.call_tads_insulation_index),
            'directionality': (tad.directionality_index, tad.call_tads_directionality_index),
        }
        if args.algorithm not in algorithms:
            raise ValueError("Algorithm (-a) can only be directionality, insulation or ninsulation, "
                             "not %s!" % args.algorithm)
        index_algorithm, tad_algorithm = algorithms[args.algorithm]

        logging.info("Calculating index...")
        index = index_algorithm(matrix, regions=regions, window_size=args.window_size)

        # Build the output lines once, each terminated by exactly one newline,
        # so file and stdout output are identical.
        if write_values:
            lines = ["{}\t{}\t{}\t.\t{}\n".format(r.chromosome, r.start - 1, r.end, value)
                     for r, value in zip(regions, index)]
        else:
            logging.info("Calling TADs")
            called_tads = tad_algorithm(index, cutoff=args.cutoff, regions=regions)
            lines = ["%s\t%d\t%d\n" % (region.chromosome, region.start - 1, region.end)
                     for region in called_tads]

        if args.output is not None:
            with open(os.path.expanduser(args.output), 'w') as o:
                o.writelines(lines)
        else:
            sys.stdout.writelines(lines)

    def subset(self):
        parser = argparse.ArgumentParser(
            prog="tadtool " + sys.argv[1],
            description='Reduce a matrix to a smaller region.'
        )

        parser.add_argument(
            'matrix',
            help='''Square Hi-C Matrix as tab-delimited or .npy file (created with numpy.save)
                    or sparse matrix format (each line: 
                    <row region index> <column region index> <matrix value>)'''
        )

        parser.add_argument(
            'regions',
            help='''BED file (no header) with regions corresponding to the number of rows in the provided matrix.
                    Fourth column, if present, denotes name field, which is used as an identifier in sparse 
                    matrix notation.'''
        )

        parser.add_argument(
            'sub_region',
            help='''
                Region of the Hi-C matrix to display in plot. Format: <chromosome>:<start>-<end>, e.g.
                chr12:31000000-33000000
            '''
        )

        parser.add_argument(
            'output_matrix',
            help='''Output matrix file.'''
        )

        parser.add_argument(
            'output_regions',
            help='''Output regions file.'''
        )

        args = parser.parse_args(sys.argv[2:])
        regions_file = os.path.expanduser(args.regions)
        matrix_file = os.path.expanduser(args.matrix)
        output_matrix_file = os.path.expanduser(args.output_matrix)
        output_regions_file = os.path.expanduser(args.output_regions)
        sub_region = args.sub_region

        import tadtool.tad as tad
        import numpy as np

        regions, ix_converter = tad.load_regions(regions_file)

        sr, start_ix, end_ix = tad.sub_regions(regions, sub_region)

        with open(output_regions_file, 'w') as o:
            for r in sr:
                o.write("{}\t{}\t{}\n".format(r.chromosome, r.start-1, r.end))

        m = None

        try:  # numpy binary format
            m = np.load(matrix_file)
        except IOError:  # not an .npy file
            # check number of fields in file
            with open(matrix_file, 'r') as f:
                line = f.readline()
                while line.startswith('#'):
                    line = f.readline()
                line = line.rstrip()
                n_fields = len(line.split("\t"))

            if n_fields > 3:  # square matrix format
                m = np.loadtxt(matrix_file)

        if m is not None:
            m_sub = m[start_ix:end_ix+1, start_ix:end_ix+1]
            np.save(output_matrix_file, m_sub)
        else:
            edges = []
            with open(matrix_file, 'r') as f:
                for line in f:
                    if line.startswith("#"):
                        continue
                    line = line.rstrip()
                    fields = line.split("\t")
                    if ix_converter is None:
                        source, sink, weight = int(fields[0]), int(fields[1]), float(fields[2])
                    else:
                        source = ix_converter[fields[0]]
                        sink = ix_converter[fields[1]]
                        weight = float(fields[2])

                    if source < start_ix or sink > end_ix:
                        continue

                    edges.append((source - start_ix, sink - start_ix, weight))

            with open(output_matrix_file, 'w') as o:
                for (source, sink, weight) in edges:
                    o.write("{}\t{}\t{}\n".format(source, sink, weight))


# Script entry point: constructing TADtool parses sys.argv and runs the
# requested subcommand from within its __init__.
if __name__ == '__main__':
    TADtool()
