#!/usr/bin/env python

from __future__ import print_function
import argparse
import os.path
import textwrap
import sys

import logging
# Configure the root logger when the module is loaded: INFO and above,
# each record prefixed with a timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)


class TADtool(object):
    """Command-line front end with ``plot`` and ``tads`` subcommands.

    Instantiating the class parses ``sys.argv``: the first argument selects
    the subcommand, which is implemented by the method of the same name. Each
    subcommand method then parses its own options from ``sys.argv[2:]``.
    An unknown subcommand prints the help text and exits with status 1.
    """

    # Names that may be dispatched to from the command line. An explicit list
    # is used instead of ``hasattr`` so that arbitrary attribute names (for
    # example ``__class__`` or ``__init__``) cannot be invoked as commands.
    _commands = ('plot', 'tads')

    def __init__(self):
        usage = '''\
                tadtool <command> [options]

                Commands:
                    plot                Main interactive TADtool plotting window
                    tads                Call TADs with pre-defined parameters

                Run tadtool <command> -h for help on a specific command.
                '''
        parser = argparse.ArgumentParser(
            description="Assistant to find cutoffs in TAD calling algorithms.",
            usage=textwrap.dedent(usage)
        )

        parser.add_argument('command', help='Subcommand to run')

        # parse_args defaults to [1:] for args, but you need to
        # exclude the rest of the args too, or validation will fail
        args = parser.parse_args(sys.argv[1:2])
        if args.command not in self._commands:
            print('Unrecognized command')
            parser.print_help()
            sys.exit(1)
        # use dispatch pattern to invoke method with same name
        getattr(self, args.command)()

    ##########################################################################
    #
    #                               Shared helpers
    #
    ##########################################################################

    @staticmethod
    def _load_matrix_and_regions(tad, matrix_file, regions_file):
        """Load the Hi-C matrix and its regions and check they are consistent.

        :param tad: the imported ``tadtool.tad`` module
        :param matrix_file: path to the matrix file (``~`` is expanded)
        :param regions_file: path to the regions file (``~`` is expanded)
        :return: tuple ``(matrix, regions)``
        :raises ValueError: if the number of regions differs from the number
                            of matrix rows, or if the matrix is not square
        """
        logging.info("Loading matrix...")
        matrix = tad.HicMatrixFileReader().matrix(os.path.expanduser(matrix_file))
        logging.info("Loading regions...")
        regions = tad.HicRegionFileReader().regions(os.path.expanduser(regions_file))

        if len(regions) != matrix.shape[0]:
            raise ValueError("Regions must be the same length as rows in matrix!")

        if matrix.shape[0] != matrix.shape[1]:
            raise ValueError("Matrix must be square!")

        return matrix, regions

    @staticmethod
    def _index_function(tad, algorithm, normalisation_window):
        """Return the index-calculation function for the given algorithm name.

        :param tad: the imported ``tadtool.tad`` module
        :param algorithm: one of 'insulation', 'ninsulation', 'directionality'
        :param normalisation_window: bound as the ``normalisation_window``
                                     keyword when 'ninsulation' is selected
        :raises ValueError: if ``algorithm`` is not one of the three names
        """
        if algorithm == 'insulation':
            return tad.insulation_index
        if algorithm == 'ninsulation':
            from functools import partial
            return partial(tad.normalised_insulation_index,
                           normalisation_window=normalisation_window)
        if algorithm == 'directionality':
            return tad.directionality_index
        raise ValueError("Algorithm (-a) can only be directionality, insulation or ninsulation, "
                         "not %s!" % algorithm)

    @staticmethod
    def _window_sizes(window_size_args, algorithm):
        """Turn the values of ``-w/--window-sizes`` into window sizes.

        :param window_size_args: list of strings from the command line, or
                                 None if the option was not given
        :param algorithm: algorithm name; only used to pick the default range
        :return: window sizes. A list of ints, except when a single file
                 name is given and ``numpy.load`` can read it, in which case
                 whatever ``numpy.load`` returns is passed through unchanged.
        :raises ValueError: if exactly two values are given
        """
        if window_size_args is None:
            import numpy as np
            # directionality gets the wider default range (up to 10**6.5)
            upper_exponent = 6 if algorithm in ('insulation', 'ninsulation') else 6.5
            return [int(i) for i in np.logspace(4, upper_exponent, 100)]

        n_values = len(window_size_args)
        if n_values == 1:
            import numpy as np
            # Expand '~' once so both the numpy and the text reader see the
            # same path.
            path = os.path.expanduser(window_size_args[0])
            try:
                return np.load(path)
            except (IOError, ValueError):
                # numpy rejects a file that is not in .npy format with
                # IOError or ValueError depending on its version; in both
                # cases fall back to whitespace-delimited text.
                with open(path, 'r') as ws_file:
                    return [int(string_size) for string_size in ws_file.read().split()]
        if n_values == 3:
            start, stop, step = [int(value) for value in window_size_args]
            # materialise so every branch hands back a real sequence
            return list(range(start, stop, step))
        if n_values > 3:
            return [int(window_size) for window_size in window_size_args]
        raise ValueError("Wrong window size declaration, see help (-h) for details!")

    ##########################################################################
    #
    #                                Subcommands
    #
    ##########################################################################

    def plot(self):
        """Run the ``plot`` subcommand.

        Parses options from ``sys.argv[2:]``, loads matrix and regions,
        obtains the index data (calculated with the selected algorithm, or
        read from the ``-d`` file) and shows the TADtool plot, blocking until
        the window is closed.

        :raises ValueError: on inconsistent input dimensions, an unknown
                            algorithm name or a malformed ``-w`` option
        """
        parser = argparse.ArgumentParser(
            prog="tadtool " + sys.argv[1],
            description='Main interactive TADtool plotting window'
        )

        parser.add_argument(
            'matrix',
            help='''Square Hi-C Matrix as tab-delimited or .npy file (created with numpy.save)'''
        )

        parser.add_argument(
            'regions',
            help='''BED3 file (no header) with regions corresponding to the number of rows in the provided matrix.'''
        )

        parser.add_argument(
            'plotting_region',
            help='''
                Region of the Hi-C matrix to display in plot. Format: <chromosome>:<start>-<end>, e.g.
                chr12:31000000-33000000
            '''
        )

        parser.add_argument(
            '-w', '--window-sizes', dest='window_sizes',
            nargs='+',
            help='''
                Window sizes in base pairs used for TAD calculation. You can pass
                (1) a filename with whitespace-delimited window sizes,
                (2) three integers denoting start, stop, and step size to generate a range of window sizes, or
                (3) more than three integers to define window sizes directly.
                If left at default, window sizes will be logarithmically spaced between 10**4 and 10**6, or 10**6.5
                for the insulation and directionality index, respectively.
            '''
        )

        parser.add_argument(
            '-a', '--algorithm', dest='algorithm',
            default='insulation',
            help='''TAD-calling algorithm. Options: insulation, ninsulation, directionality. Default: insulation.'''
        )

        parser.add_argument(
            '-m', '--max-distance', dest='max_dist',
            type=int,
            default=3000000,
            help='''Maximum distance in base-pairs away from the diagonal to be shown in Hi-C plot. Default: 3000000'''
        )

        parser.add_argument(
            '-n', '--normalisation-window', dest='normalisation_window',
            type=int,
            help='''Normalisation window in number of regions. Only affects ninsulation algorithm. If not specified,
                    window will be the whole chromosome.'''
        )

        parser.add_argument(
            '-d', '--data', dest='data',
            help='''
                Matrix with index data. Rows correspond to window sizes, columns to Hi-C matrix bins.
                If provided, suppresses inbuilt index calculation.
            '''
        )

        args = parser.parse_args(sys.argv[2:])

        # Imported after argument parsing, so -h and usage errors are handled
        # before the tadtool package is loaded.
        import tadtool.tad as tad

        matrix, regions = self._load_matrix_and_regions(tad, args.matrix, args.regions)

        logging.info("Getting window sizes...")
        window_sizes = self._window_sizes(args.window_sizes, args.algorithm)

        if args.data is None:
            tad_algorithm = self._index_function(tad, args.algorithm, args.normalisation_window)
            data, window_sizes = tad.data_array(hic_matrix=matrix, regions=regions,
                                                tad_method=tad_algorithm, window_sizes=window_sizes)
        else:
            logging.info("Loading data file...")
            data = tad.HicMatrixFileReader().matrix(os.path.expanduser(args.data))
            if data.shape[0] != len(window_sizes):
                raise ValueError("Number of rows in data (%d) must equal the number of window sizes (%d)!" % (
                    data.shape[0], len(window_sizes)))
            if data.shape[1] != matrix.shape[0]:
                # report the column count, which is what is being compared
                raise ValueError("Number of columns in data (%d) must equal matrix rows (%d)!" % (
                    data.shape[1], matrix.shape[0]))

        import tadtool.plot as tp
        logging.info("Done. Showing plot...")
        tad_plot = tp.TADtoolPlot(matrix, regions, data=data, window_sizes=window_sizes,
                                  norm='lin', max_dist=args.max_dist, algorithm=args.algorithm)
        tad_plot.plot(args.plotting_region)
        tp.plt.show(block=True)  # necessary so that window won't close immediately

    def tads(self):
        """Run the ``tads`` subcommand.

        Parses options from ``sys.argv[2:]``, calculates the index for the
        given window size, calls TADs at the given cutoff and writes them as
        tab-separated ``chromosome, start, end`` lines to the output file, or
        to standard output if no output file was given.

        :raises ValueError: on inconsistent input dimensions or an unknown
                            algorithm name
        """
        parser = argparse.ArgumentParser(
            prog="tadtool " + sys.argv[1],
            description='Call TADs with pre-defined parameters'
        )

        parser.add_argument(
            'matrix',
            help='''Square Hi-C Matrix as tab-delimited or .npy file (created with numpy.save)'''
        )

        parser.add_argument(
            'regions',
            help='''BED3 file (no header) with regions corresponding to the number of rows in the provided matrix.'''
        )

        parser.add_argument(
            'window_size',
            type=int,
            help='''Window size in base pairs'''
        )

        parser.add_argument(
            'cutoff',
            type=float,
            help='''Cutoff for TAD-calling algorithm at given window size.'''
        )

        parser.add_argument(
            'output',
            nargs='?',
            help='''Optional output file to save TADs.'''
        )

        parser.add_argument(
            '-a', '--algorithm', dest='algorithm',
            default='insulation',
            help='''TAD-calling algorithm. Options: insulation, ninsulation, directionality. Default: insulation.'''
        )

        parser.add_argument(
            '-n', '--normalisation-window', dest='normalisation_window',
            type=int,
            help='''Normalisation window in number of regions. Only affects ninsulation algorithm. If not specified,
                    window will be the whole chromosome.'''
        )

        args = parser.parse_args(sys.argv[2:])

        # Imported after argument parsing, so -h and usage errors are handled
        # before the tadtool package is loaded.
        import tadtool.tad as tad

        matrix, regions = self._load_matrix_and_regions(tad, args.matrix, args.regions)

        # Raises ValueError for an unknown name, so the choice below only has
        # to distinguish directionality from the two insulation variants.
        index_algorithm = self._index_function(tad, args.algorithm, args.normalisation_window)
        if args.algorithm == 'directionality':
            tad_algorithm = tad.call_tads_directionality_index
        else:
            tad_algorithm = tad.call_tads_insulation_index

        logging.info("Calculating index...")
        index = index_algorithm(matrix, regions=regions, window_size=args.window_size)
        logging.info("Calling TADs")
        tads = tad_algorithm(index, cutoff=args.cutoff, regions=regions)

        if args.output is not None:
            output = os.path.expanduser(args.output)
            with open(output, 'w') as o:
                for region in tads:
                    o.write("%s\t%d\t%d\n" % (region.chromosome, region.start, region.end))
        else:
            for region in tads:
                print("%s\t%d\t%d" % (region.chromosome, region.start, region.end))


# Script entry point: constructing TADtool parses sys.argv and runs the
# selected subcommand.
if __name__ == "__main__":
    TADtool()
