#!/usr/bin/python2

"""Script to run SHARPener

TODO:
1. Correct the name of the output plot
2. make script parallel

"""

import os
import string
import sys
import numpy as np
import pyaml
import json
import argparse
from astropy.io import ascii
import sharpener.sharpener as sharpy
import tabulate
from sharpener.sharp_modules import cont_src as cont_src
from sharpener.sharp_modules import spec_ex as spec_ex
from sharpener.sharp_modules import absorption_plot as abs_pl
import time
import glob
from PyPDF2 import PdfFileMerger
import zipfile

# Starting time of the script; used at the very end to report the total
# wall-clock runtime.
time_start = time.time()

# Create and parse argument list
# ++++++++++++++++++++++++++++++

def is_valid_file(parser, arg):
    """Check that a path given on the command line exists.

    Used as the argparse ``type`` callback for the configuration file.

    Args:
        parser: Object whose ``error`` method is called with a message
            when the path does not exist (here the argparse parser).
        arg: Path string to check.

    Returns:
        ``arg`` unchanged.
    """
    if os.path.exists(arg):
        return arg

    # argparse's ``error`` normally terminates the program; the return
    # below only matters for parser objects whose ``error`` comes back.
    parser.error("The file '%s' does not exist!" % arg)
    return arg

parser = argparse.ArgumentParser(description='Run sharpener script')

# Option taking the file name under which a copy of the default
# configuration is written
parser.add_argument(
    "-gd", "--generate_default",
    default=False,
    help='Generate a default configuration file (YAML format)')

# Configuration file to run with; the given path must exist
parser.add_argument(
    "-c", "--config",
    type=lambda a: is_valid_file(parser, a),
    default=False,
    help='SHARPener configuration file (YAML format)')

# Optional switches: each one enables a single processing step
for step_flag, step_help in (
        ("--do_cont", 'Enable continuum extraction'),
        ("--do_spectra", 'Enable spectra extraction'),
        ("--do_plots", 'Enable plotting')):
    parser.add_argument(step_flag, action="store_true", default=False,
                        help=step_help)

args = parser.parse_args()

# Exit if the script was started without any command line argument
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
if not sys.argv[1:]:
    print("\n\t!!! no arguments were parsed from user. Try:\n\n\trun_sharpener -h\n")
    sys.exit(0)

# Make default file
# +++++++++++++++++
# When -gd/--generate_default is given, copy the default configuration
# shipped with the sharpener package to the requested file name and exit
# without running any processing step.
if args.generate_default:
    time_start_load = time.time()
    from shutil import copyfile
    import sharpener

    # the option value is the name of the file to create
    configfile = args.generate_default
    print("## Getting parameter file: {}".format(configfile))

    # directory that contains the sharpener package
    SHARPENER_PATH = os.path.dirname(os.path.dirname(os.path.abspath(sharpener.__file__)))
    copyfile(os.path.join(SHARPENER_PATH, 'sharpener', 'sharpener_default.yml'),
             configfile)

    print("## !!! edit parameter in your current directory before use !!!")
    print("## Done ({0:.2f}s)".format(time.time() - time_start_load))

    sys.exit(0)

# Set script parameters
# +++++++++++++++++++++
do_continuum_extraction = args.do_cont
do_spectra_extraction = args.do_spectra
do_plots = args.do_plots

# Load parameter file
# +++++++++++++++++++
# Every step below reads its settings from spar.cfg_par, so a missing
# configuration file is reported as a usage error here instead of
# surfacing later as a NameError on spar.
if not args.config:
    parser.error("no configuration file given; use -c/--config <file> "
                 "(or -gd <file> to generate a default one)")

time_start_load = time.time()
print("## Load parameter file")
spar = sharpy.sharpener(args.config)

print("## Load parameter file ... Done ({0:.2f}s)".format(
      time.time() - time_start_load))



# Find continuum sources
# ++++++++++++++++++++++
# The step runs when requested on the command line or when it is enabled
# in the configuration (only looked at if the flag was not given).
run_source_finder = do_continuum_extraction
if not run_source_finder:
    run_source_finder = spar.cfg_par['source_finder'].get('enable',False) == True

if run_source_finder:
    time_start_find = time.time()

    # make sure the module sees the step as enabled
    spar.cfg_par['source_finder']['enable'] = True

    print("## Find continuum sources")

    # get sources in continuum image
    sources = cont_src.find_src_imsad(spar.cfg_par)

    elapsed_find = time.time() - time_start_find
    print("## Find continuum sources ... Done ({0:.2f}s)".format(elapsed_find))

    # read the source table back from the absdir directory and show it
    src_table_path = spar.cfg_par['general']['absdir']+'mir_src_sharpener.csv'
    src_list = ascii.read(src_table_path)
    print(src_list)

# Extract spectra
# +++++++++++++++
# The step runs when requested on the command line or when it is enabled
# in the configuration (only looked at if the flag was not given).
run_spec_ex = do_spectra_extraction
if not run_spec_ex:
    run_spec_ex = spar.cfg_par['spec_ex'].get('enable',False) == True

if run_spec_ex:
    time_start_extract = time.time()

    # make sure the module sees the step as enabled
    spar.cfg_par['spec_ex']['enable'] = True

    print("## Extract HI spectra from cube")

    spectra = spec_ex.abs_ex(spar.cfg_par)

    elapsed_extract = time.time() - time_start_extract
    print("## Extract HI spectra from cube ... Done ({0:.2f}s)".format(elapsed_extract))

def _merge_plot_pdfs(plot_dir, name_prefix, plot_kind, nothing_found_msg):
    """Merge all per-source PDF plots of one kind into a single PDF.

    The merged file is written to
    ``<plot_dir><name_prefix>_all_plots_<plot_kind>.pdf``. If
    ``<plot_dir><name_prefix>_continuum.pdf`` exists it becomes the first
    part of the merged document.

    Args:
        plot_dir: Plot directory from the configuration. It is used as a
            plain string prefix, so it is expected to end with "/".
        name_prefix: Prefix of the continuum and merged file names.
        plot_kind: Plot flavour in the file names ("detailed" or
            "compact").
        nothing_found_msg: Message printed when no plot matches.
    """
    merged_name = "{0:s}{1:s}_all_plots_{2:s}.pdf".format(
        plot_dir, name_prefix, plot_kind)
    merged_basename = os.path.basename(merged_name)

    # A merged file left over from a previous run can itself match the
    # pattern (when the prefix contains a "J"); never merge it into itself.
    plot_list = sorted(
        plot for plot in glob.glob(
            "{0:s}/*J*_{1:s}.pdf".format(plot_dir, plot_kind))
        if os.path.basename(plot) != merged_basename)

    if not plot_list:
        print(nothing_found_msg)
        return

    continuum_plot = "{0:s}{1:s}_continuum.pdf".format(plot_dir, name_prefix)
    if os.path.exists(continuum_plot):
        plot_list.insert(0, continuum_plot)

    pdf_merger = PdfFileMerger()
    try:
        for plot_file in plot_list:
            pdf_merger.append(plot_file)
        pdf_merger.write(merged_name)
    finally:
        # release the file handles the merger keeps open for its inputs
        pdf_merger.close()


# Plot spectra
# ++++++++++++
if do_plots  or spar.cfg_par['abs_plot'].get('enable',False) == True:
    time_start_plot = time.time()
    spar.cfg_par['abs_plot']['enable'] = True

    print("## Plotting spectra")

    abs_pl.create_all_abs_plots(spar.cfg_par)

    print("## Plotting spectra ... Done ({0:.2f}s)".format(
        time.time() - time_start_plot))

    # Merge plots:
    # ++++++++++++
    # Only done when merging is switched on and the plots are PDF files
    general_cfg = spar.cfg_par['general']
    if general_cfg['merge_plots'] and general_cfg['plot_format'] == "pdf":

        time_start_merge = time.time()

        print("## Merging plots")

        plot_dir = general_cfg['plotdir']
        # Second-to-last "/"-separated component of workdir, i.e. the
        # directory name when workdir ends with "/". Kept exactly as
        # before so the continuum plot name stays the same.
        work_name = general_cfg['workdir'].split("/")[-2]

        # Merge the detailed plots
        # ++++++++++++++++++++++++
        _merge_plot_pdfs(plot_dir, work_name, "detailed",
                         "No detailed plots found. Continue")

        # Merge compact plots
        # +++++++++++++++++++
        _merge_plot_pdfs(plot_dir, work_name, "compact",
                         "No plots found. Continue")

        print("## Merging plots ... Done ({0:.2f}s)".format(
            time.time() - time_start_merge))

        # Create a zip file with all the plots
        # ++++++++++++++++++++++++++++++++++++
        time_start_zip = time.time()
        print("## Create zip files")

        plot_list = glob.glob("{0:s}/*all_plots*.pdf".format(plot_dir))

        if plot_list:
            # Name the archive after the work directory instead of a
            # hard-coded source name; it is still written to the current
            # working directory.
            zip_name = "{0:s}_plots.zip".format(work_name)

            with zipfile.ZipFile(zip_name, 'w') as myzip:
                for plot in plot_list:
                    # store the files without their directory part
                    myzip.write(plot, os.path.basename(plot))
        else:
            print("No files to zip.")

        print("## Create zip files ... Done ({0:.2f}s)".format(
            time.time() - time_start_zip))

# Script finished: report the total runtime since the script started
total_runtime = time.time() - time_start
print("Script finished ({0:.2f}s)".format(total_runtime))
