#!/usr/bin/python2

"""Script to run SHARPener

This version of the script is parallelised.
The script assumes that it is located in the directory

Issue:
    The SDSS query does not work in parallel. It is best to run in parallel
    without SDSS first and then run only the SDSS step with a single process.

"""

import os
import string
import sys
import numpy as np
import pyaml
import json
import argparse
import imp
import multiprocessing as mp
import functools
from astropy.io import ascii
import tabulate
# import cont_src as cont_src
# import spec_ex as spec_ex
import time
# import absorption_plot as abs_pl
import glob
from PyPDF2 import PdfFileMerger
import zipfile


def _merge_beam_plots(plot_dir, beam_label, plot_kind, missing_msg, proc):
    """Merge the per-source ``*J*_<plot_kind>.pdf`` plots into a single PDF.

    The continuum overview pages (``<beam_label>_continuum.pdf`` and
    ``<beam_label>_continuum_and_sdss.pdf``) are put in front of the
    per-source plots when they exist. The result is written to
    ``<plot_dir><beam_label>_all_plots_<plot_kind>.pdf``.

    Parameters
    ----------
    plot_dir : str
        Plot directory taken from the configuration; it is concatenated
        directly with file names, so it is expected to end with "/".
    beam_label : str
        Prefix used for the continuum and merged plot file names.
    plot_kind : str
        Either "detailed" or "compact".
    missing_msg : str
        Format string (with one ``{0:d}`` field for the process id) that is
        printed when no per-source plots are found.
    proc : int
        Process id used in log messages.
    """
    plot_list = glob.glob("{0:s}/*J*_{1:s}.pdf".format(plot_dir, plot_kind))

    if not plot_list:
        print(missing_msg.format(proc))
        return

    plot_list.sort()

    # continuum plot only with radio sources
    cont_plot_name = "{0:s}{1:s}_continuum.pdf".format(plot_dir, beam_label)

    # continuum plot with radio and sdss sources
    cont_sdss_plot_name = "{0:s}{1:s}_continuum_and_sdss.pdf".format(
        plot_dir, beam_label)

    # insert the sdss version first so that the radio-only plot ends up
    # as the very first page, followed by the radio+sdss plot
    if os.path.exists(cont_sdss_plot_name):
        plot_list.insert(0, cont_sdss_plot_name)

    if os.path.exists(cont_plot_name):
        plot_list.insert(0, cont_plot_name)

    plot_name = "{0:s}{1:s}_all_plots_{2:s}.pdf".format(
        plot_dir, beam_label, plot_kind)

    pdf_merger = PdfFileMerger()
    try:
        for plot_file in plot_list:
            pdf_merger.append(plot_file)
        pdf_merger.write(plot_name)
    finally:
        # release the input files held open by the merger
        pdf_merger.close()


def run_sharpener(beam_directory_list, do_continuum_extraction, do_spectra_extraction, do_plots, do_sdss, beam_count):
    """Run the SHARPener steps for a single beam.

    The function is meant to be mapped over beam indices by a process pool.
    It loads ``<beam_dir>/<beam_name>_sharpener_default.yml`` and, provided
    the continuum image and the cube named in that configuration exist,
    runs the enabled steps and optionally merges the resulting PDF plots.

    Parameters
    ----------
    beam_directory_list : list of str
        Paths of all beam directories.
    do_continuum_extraction : bool
        Run ``cont_src.find_src_imsad``.
    do_spectra_extraction : bool
        Run ``spec_ex.abs_ex``.
    do_plots : bool
        Run ``abs_pl.create_all_abs_plots``.
    do_sdss : bool
        Run ``sdss_match.get_sdss_sources``.
    beam_count : int
        Index into ``beam_directory_list`` selecting the beam to process.

    Returns
    -------
    None
    """
    # project imports are done inside the function, i.e. in the process
    # that actually executes it
    import sharpener.sharpener as sharpy
    from sharpener.sharp_modules import cont_src as cont_src
    from sharpener.sharp_modules import spec_ex as spec_ex
    from sharpener.sharp_modules import absorption_plot as abs_pl
    from sharpener.sharp_modules import sdss_match
    imp.reload(sharpy)

    time_start_run = time.time()

    # get process
    proc = os.getpid()

    # get beam
    beam_directory = beam_directory_list[beam_count]
    beam_name = os.path.basename(beam_directory)

    print("(Pid {0:d}) ## Running sharpener for {1:s}".format(proc, beam_name))

    # Load parameter file
    # +++++++++++++++++++
    parameter_file = "{0:s}/{1:s}_sharpener_default.yml".format(
        beam_directory, beam_name)

    if not os.path.exists(parameter_file):
        print("(Pid {0:d}) ERROR: File {1:s} not found. Abort".format(
            proc, parameter_file))
        # Skip only this beam. sys.exit() is avoided on purpose because
        # this function runs inside a pool worker.
        return

    spar = sharpy.sharpener(parameter_file)
    general_cfg = spar.cfg_par['general']

    # continue only if all files are available
    if not (os.path.exists(general_cfg['contname'])
            and os.path.exists(general_cfg['cubename'])):
        print("(Pid {0:d}) ## ERROR: Could not find all files. Finished SHRAPener for {1:s} ({2:.2f}s)".format(
            proc, beam_name, time.time() - time_start_run))
        return

    # Find continuum sources
    # ++++++++++++++++++++++
    if do_continuum_extraction:
        cont_src.find_src_imsad(spar.cfg_par)

    # Get SDSS sources
    # ++++++++++++++++
    if do_sdss:
        sdss_match.get_sdss_sources(spar.cfg_par)

    # Extract spectra
    # +++++++++++++++
    if do_spectra_extraction:
        spec_ex.abs_ex(spar.cfg_par)

    # Plot spectra
    # ++++++++++++
    if do_plots:
        abs_pl.create_all_abs_plots(spar.cfg_par)

    # Merge plots:
    # ++++++++++++
    if general_cfg['merge_plots'] and general_cfg['plot_format'] == "pdf":
        plot_dir = general_cfg['plotdir']
        # second to last path component of the work directory; with a
        # trailing "/" on workdir this is the name of the work directory
        beam_label = general_cfg['workdir'].split("/")[-2]

        _merge_beam_plots(
            plot_dir, beam_label, "detailed",
            "(Pid {0:d}) No detailed plots found. Continue", proc)
        _merge_beam_plots(
            plot_dir, beam_label, "compact",
            "(Pid {0:d}) No plots found. Continue", proc)

    # report success independently of whether plots were merged
    print("(Pid {0:d}) ## Finished SHRAPener for {1:s} ({2:.2f}s)".format(
        proc, beam_name, time.time() - time_start_run))


if __name__ == '__main__':
    # Command-line entry point:
    #   1. parse the arguments,
    #   2. run run_sharpener() for every "beam??" directory in a process pool,
    #   3. collect the merged plots and the source lists into two zip files
    #      located in the beam directory.

    # starting time of script
    time_start = time.time()

    # Create and parse argument list
    # ++++++++++++++++++++++++++++++
    parser = argparse.ArgumentParser(
        description='Run sharpener script')

    # 1st argument: number of worker processes
    parser.add_argument("n_processes", type=int,
                        help='Set the number of processes')

    # 2nd argument: path to beams
    parser.add_argument("path_to_beams", type=str,
                        help='Set the full path to the directory where the beams are')

    # Optional arguments: enable certain steps (only used with --do_not_all)
    parser.add_argument("--do_cont", action="store_true", default=False,
                        help='Enable continuum extraction')

    parser.add_argument("--do_spectra", action="store_true", default=False,
                        help='Enable spectra extraction')

    parser.add_argument("--do_plots", action="store_true", default=False,
                        help='Enable plotting')

    parser.add_argument("--do_sdss", action="store_true", default=False,
                        help='Enable SDSS source query')

    parser.add_argument("--do_not_all", action="store_true", default=False,
                        help='Run only the steps enabled with the --do_* options instead of all steps')

    args = parser.parse_args()

    # Set script parameters
    # +++++++++++++++++++++
    # Both positional arguments are required by argparse, so they can never
    # be None here; only the value of n_processes needs checking.
    if args.n_processes < 1:
        print("ERROR. The number of processes must be at least 1. Abort")
        sys.exit(1)

    n_processes = args.n_processes

    # make sure the path ends with a separator because all patterns and
    # output names below are built by appending to it
    beam_path = os.path.join(args.path_to_beams, "")

    if args.do_not_all:
        do_continuum_extraction = args.do_cont
        do_spectra_extraction = args.do_spectra
        do_plots = args.do_plots
        do_sdss = args.do_sdss
    else:
        do_continuum_extraction = True
        do_spectra_extraction = True
        do_plots = True
        do_sdss = True

    # get a sorted list of beam directories
    beam_directory_list = sorted(glob.glob("{0:s}beam??".format(beam_path)))

    if not beam_directory_list:
        print("ERROR. No directories found. Aport")
        sys.exit(1)

    # indices of the beams; each pool task processes one index
    beam_count = range(len(beam_directory_list))

    # create pool object with number of processes
    pool = mp.Pool(processes=n_processes)

    # bind the arguments shared by all tasks; the beam index is supplied
    # by the map call
    fct_partial = functools.partial(
        run_sharpener, beam_directory_list, do_continuum_extraction, do_spectra_extraction, do_plots, do_sdss)

    # create and run map; always shut the pool down, also on errors
    try:
        pool.map(fct_partial, beam_count)
    finally:
        pool.close()
        pool.join()

    # Create a zip file with all the plots
    # ++++++++++++++++++++++++++++++++++++
    time_start_zip = time.time()
    print("## Create zip files for plots")

    plot_list = sorted(glob.glob(
        "{0:s}beam??/plot/*all_plots*.pdf".format(beam_path)))

    if plot_list:

        with zipfile.ZipFile('{0:s}all_plots.zip'.format(beam_path), 'w') as myzip:

            for plot in plot_list:
                myzip.write(plot, os.path.basename(plot))

        print("## Create zip files for plots ... Done ({0:.2f}s)".format(
            time.time() - time_start_zip))
    else:
        print("No files to zip.")

    # Create a zip file with all the source lists
    # +++++++++++++++++++++++++++++++++++++++++++
    time_start_zip = time.time()
    print("## Create zip files for source list")

    csv_list = glob.glob(
        "{0:s}beam??/abs/mir_src_sharpener.csv".format(beam_path))
    csv_sdss_list = glob.glob(
        "{0:s}beam??/abs/beam??_sdss_src.csv".format(beam_path))
    csv_sdss_radio_list = glob.glob(
        "{0:s}beam??/abs/beam??_radio_sdss_src.csv".format(beam_path))
    karma_list = glob.glob(
        "{0:s}beam??/abs/karma_src_sharpener.ann".format(beam_path))

    # do not create if there are no source at all
    if csv_list:

        # (files, add beam prefix): files whose name is identical in every
        # beam directory get the beam name as prefix inside the archive
        zip_groups = (
            (csv_list, True),
            (csv_sdss_list, False),
            (csv_sdss_radio_list, False),
            (karma_list, True),
        )

        with zipfile.ZipFile('{0:s}all_sources.zip'.format(beam_path), 'w') as myzip:

            for file_list, add_beam_prefix in zip_groups:
                for file_name in sorted(file_list):
                    arc_name = os.path.basename(file_name)
                    if add_beam_prefix:
                        # files are located in <beam_path>beam??/abs/
                        beam = os.path.basename(
                            os.path.dirname(os.path.dirname(file_name)))
                        arc_name = "{0:s}_{1:s}".format(beam, arc_name)
                    myzip.write(file_name, arc_name)

        print("## Create zip files ... Done ({0:.2f}s)".format(
            time.time() - time_start_zip))
    else:
        print("No files to zip.")

    # Script finished
    print("Script finished ({0:.2f}s)".format(
        time.time() - time_start))
